Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[run]
branch = True

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:

.*# pragma: no cover
.*# pragma: no branch

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ unit-tests:
coverage run -m unittest -v

check-test-coverage:
coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/eval_accuracy_longcontext.py,benchmarks/math_utils.py" --fail-under=96
coverage report -m --omit="jetstream/tests/*,jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/eval_accuracy_longcontext.py,benchmarks/math_utils.py,benchmarks/tests/*" --fail-under=90
41 changes: 29 additions & 12 deletions jetstream/core/server_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import os
import signal
import sys
import threading
import time
import traceback
Expand All @@ -41,6 +42,21 @@

_HOST = "[::]"

# Create seperate logger to log all INFO message for this module. These show
# stages of server startup and inform user if server is ready to take requests.
# The default logger created in orchestrator.py only logs WARNINGs and above
logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

info_handler = logging.StreamHandler(sys.stdout)
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(formatter)
logger.addHandler(info_handler)


class JetStreamServer:
"""JetStream grpc server."""
Expand Down Expand Up @@ -120,7 +136,7 @@ def create_driver(
prefill_params = [pe.load_params() for pe in engines.prefill_engines]
generate_params = [ge.load_params() for ge in engines.generate_engines]
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
logging.info("Loaded all weights.")
logger.info("Loaded all weights.")
if metrics_collector:
metrics_collector.get_model_load_time_metric().set(
time.time() - model_load_start_time
Expand All @@ -135,13 +151,13 @@ def create_driver(
generate_params = generate_params + shared_params

if prefill_engines is None:
prefill_engines = []
prefill_engines = [] # pragma: no branch
if generate_engines is None:
generate_engines = []
generate_engines = [] # pragma: no branch
if prefill_params is None:
prefill_params = []
prefill_params = [] # pragma: no branch
if generate_params is None:
generate_params = []
generate_params = [] # pragma: no branch

if enable_model_warmup:
prefill_engines = [engine_api.JetStreamEngine(pe) for pe in prefill_engines]
Expand Down Expand Up @@ -215,19 +231,19 @@ def run(
del lora_input_adapters_path

server_start_time = time.time()
logging.info("Kicking off gRPC server.")
logger.info("Kicking off gRPC server.")
# Setup Prometheus server
metrics_collector: JetstreamMetricsCollector = None
if metrics_server_config and metrics_server_config.port:
logging.info(
logger.info(
"Starting Prometheus server on port %d", metrics_server_config.port
)
start_http_server(metrics_server_config.port)
metrics_collector = JetstreamMetricsCollector(
model_name=metrics_server_config.model_name
)
else:
logging.info(
logger.info(
"Not starting Prometheus server: --prometheus_port flag not set"
)

Expand Down Expand Up @@ -256,7 +272,7 @@ def run(
gc.set_threshold(allocs, gen1, gen2)
print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2)

logging.info("Starting server on port %d with %d threads", port, threads)
logger.info("Starting server on port %d with %d threads", port, threads)
jetstream_server.start()

if metrics_collector:
Expand All @@ -266,10 +282,10 @@ def run(

# Setup Jax Profiler
if enable_jax_profiler:
logging.info("Starting JAX profiler server on port %s", jax_profiler_port)
logger.info("Starting JAX profiler server on port %s", jax_profiler_port)
jax.profiler.start_server(jax_profiler_port)
else:
logging.info("Not starting JAX profiler server: %s", enable_jax_profiler)
logger.info("Not starting JAX profiler server: %s", enable_jax_profiler)

# Start profiling server by default for proxy backend.
if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms:
Expand All @@ -279,6 +295,7 @@ def run(
target=proxy_util.start_profiling_server, args=(jax_profiler_port,)
)
thread.run()
logger.info("Server up and ready to process requests on port %s", port)

return jetstream_server

Expand All @@ -287,5 +304,5 @@ def get_devices() -> Any:
"""Gets devices."""
# TODO: Add more logs for the devices.
devices = jax.devices()
logging.info("Using devices: %d", len(devices))
logger.info("Using devices: %d", len(devices))
return devices