From 734516ae52339407c4216de4a8bbb72e86c66438 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 30 May 2023 15:01:34 -0700 Subject: [PATCH] Fixes integration test (#779) --- tests/integration/llm/client.py | 29 ++++++++----------- .../llm/sagemaker-endpoint-tests.py | 4 +-- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py index 09acab5ac..b69a5eeb5 100644 --- a/tests/integration/llm/client.py +++ b/tests/integration/llm/client.py @@ -10,7 +10,6 @@ import numpy as np from datetime import datetime from io import BytesIO -import hashlib logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser(description="Build the LLM configs") @@ -53,11 +52,10 @@ args = parser.parse_args() -def compute_model_name_hash(model_name): - # This mirrors the Utils.hash implementation from DJL Core - m = hashlib.sha256() - m.update(model_name) - return m.hexdigest()[:40] +def get_model_name(): + endpoint = f"http://127.0.0.1:8080/models" + res = requests.get(endpoint).json() + return res["models"][0]["modelName"] ds_raw_model_spec = { @@ -118,21 +116,18 @@ def compute_model_name_hash(model_name): "batch_size": [1, 4], "seq_length": [16, 32], "worker": 1, - "model_name": compute_model_name_hash(b"nomic-ai/gpt4all-j"), }, "no-code/databricks/dolly-v2-7b": { "max_memory_per_gpu": [10.0, 12.0], "batch_size": [1, 4], "seq_length": [16, 32], "worker": 2, - "model_name": compute_model_name_hash(b"databricks/dolly-v2-7b"), }, "no-code/google/flan-t5-xl": { "max_memory_per_gpu": [7.0, 7.0], "batch_size": [1, 4], "seq_length": [16, 32], "worker": 2, - "model_name": compute_model_name_hash(b"google/flan-t5-xl") } } @@ -266,7 +261,8 @@ def compute_model_name_hash(model_name): } -def check_worker_number(desired, model_name="test"): +def check_worker_number(desired): + model_name = get_model_name() endpoint = f"http://127.0.0.1:8080/models/{model_name}" res = requests.get(endpoint).json() if desired == len(res[0]["models"][0]["workerGroups"]): @@ -278,9 +274,9 @@ def check_worker_number(desired, model_name="test"): f"Worker number does not meet requirements! {res}") -def send_json(data, model_name="test"): +def send_json(data): headers = {'content-type': 'application/json'} - endpoint = f"http://127.0.0.1:8080/predictions/{model_name}" + endpoint = f"http://127.0.0.1:8080/invocations" resp = requests.post(endpoint, headers=headers, json=data) if resp.status_code >= 300: @@ -289,12 +285,12 @@ def send_json(data, model_name="test"): return resp -def send_image_json(img_url, data, model_name="test"): +def send_image_json(img_url, data): multipart_form_data = { 'data': BytesIO(requests.get(img_url, stream=True).content), 'json': (None, json.dumps(data), 'application/json') } - endpoint = f"http://127.0.0.1:8080/predictions/{model_name}" + endpoint = f"http://127.0.0.1:8080/invocations" resp = requests.post(endpoint, files=multipart_form_data) if resp.status_code >= 300: @@ -459,8 +455,7 @@ def test_handler(model, model_spec): ) spec = model_spec[args.model] if "worker" in spec: - check_worker_number(spec["worker"], - model_name=spec.get("model_name", "test")) + check_worker_number(spec["worker"]) for i, batch_size in enumerate(spec["batch_size"]): for seq_length in spec["seq_length"]: if "t5" in model: @@ -470,7 +465,7 @@ def test_handler(model, model_spec): params = {"max_new_tokens": seq_length} req["parameters"] = params logging.info(f"req {req}") - res = send_json(req, model_name=spec.get("model_name", "test")) + res = send_json(req) if spec.get("stream_output", False): logging.info(f"res: {res.content}") result = res.content.decode().split("\n")[:-1] diff --git a/tests/integration/llm/sagemaker-endpoint-tests.py b/tests/integration/llm/sagemaker-endpoint-tests.py index ad92e9f8e..f8ac282a0 100644 --- a/tests/integration/llm/sagemaker-endpoint-tests.py +++ b/tests/integration/llm/sagemaker-endpoint-tests.py @@ -129,9 +129,7 @@ } } -ENGINE_TO_METRIC_CONFIG_ENGINE = { - "Python" : "Accelerate" -} +ENGINE_TO_METRIC_CONFIG_ENGINE = {"Python": "Accelerate"} def get_sagemaker_session(default_bucket=DEFAULT_BUCKET,