Commit 734516a

Fixes integration test (deepjavalibrary#779)

frankfliu authored and KexinFeng committed Aug 16, 2023
1 parent 97594a4

Showing 2 changed files with 13 additions and 20 deletions.

tests/integration/llm/client.py (12 additions, 17 deletions)

@@ -10,7 +10,6 @@
 import numpy as np
 from datetime import datetime
 from io import BytesIO
-import hashlib

 logging.basicConfig(level=logging.INFO)
 parser = argparse.ArgumentParser(description="Build the LLM configs")
@@ -53,11 +52,10 @@
 args = parser.parse_args()


-def compute_model_name_hash(model_name):
-    # This mirrors the Utils.hash implementation from DJL Core
-    m = hashlib.sha256()
-    m.update(model_name)
-    return m.hexdigest()[:40]
+def get_model_name():
+    endpoint = f"http://127.0.0.1:8080/models"
+    res = requests.get(endpoint).json()
+    return res["models"][0]["modelName"]


 ds_raw_model_spec = {
@@ -118,21 +116,18 @@ def compute_model_name_hash(model_name):
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 1,
-        "model_name": compute_model_name_hash(b"nomic-ai/gpt4all-j"),
     },
     "no-code/databricks/dolly-v2-7b": {
         "max_memory_per_gpu": [10.0, 12.0],
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 2,
-        "model_name": compute_model_name_hash(b"databricks/dolly-v2-7b"),
     },
     "no-code/google/flan-t5-xl": {
         "max_memory_per_gpu": [7.0, 7.0],
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 2,
-        "model_name": compute_model_name_hash(b"google/flan-t5-xl")
     }
 }
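
With the name now discovered at runtime, the hard-coded "model_name" entries above become unnecessary. For reference, the deleted helper predicted the registered name by hashing the model id the same way DJL Core's Utils.hash does; a self-contained sketch of that scheme, kept only for illustration:

    import hashlib

    def compute_model_name_hash(model_name: bytes) -> str:
        # SHA-256 of the model id, truncated to 40 hex characters,
        # mirroring the removed helper (and DJL Core's Utils.hash).
        return hashlib.sha256(model_name).hexdigest()[:40]

    compute_model_name_hash(b"google/flan-t5-xl")

Querying the server directly removes the risk of this client-side copy drifting out of sync with the implementation in DJL Core.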

@@ -266,7 +261,8 @@ def compute_model_name_hash(model_name):
 }


-def check_worker_number(desired, model_name="test"):
+def check_worker_number(desired):
+    model_name = get_model_name()
     endpoint = f"http://127.0.0.1:8080/models/{model_name}"
     res = requests.get(endpoint).json()
     if desired == len(res[0]["models"][0]["workerGroups"]):
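
The indexing res[0]["models"][0]["workerGroups"] implies a describe-model response shaped roughly like the sketch below; everything beyond the keys the test actually reads is hypothetical:

    # Hypothetical payload of GET /models/{model_name}, inferred from the
    # indexing above; only "models" and "workerGroups" are read by the test.
    res = [
        {
            "models": [
                {
                    "modelName": "test",
                    "workerGroups": [  # len() is compared against spec["worker"]
                        {"workers": ["worker-1"]},
                    ],
                }
            ]
        }
    ]
    assert len(res[0]["models"][0]["workerGroups"]) == 1
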
@@ -278,9 +274,9 @@ def check_worker_number(desired, model_name="test"):
             f"Worker number does not meet requirements! {res}")


-def send_json(data, model_name="test"):
+def send_json(data):
     headers = {'content-type': 'application/json'}
-    endpoint = f"http://127.0.0.1:8080/predictions/{model_name}"
+    endpoint = f"http://127.0.0.1:8080/invocations"
     resp = requests.post(endpoint, headers=headers, json=data)

     if resp.status_code >= 300:
@@ -289,12 +285,12 @@ def send_json(data, model_name="test"):
     return resp


-def send_image_json(img_url, data, model_name="test"):
+def send_image_json(img_url, data):
     multipart_form_data = {
         'data': BytesIO(requests.get(img_url, stream=True).content),
         'json': (None, json.dumps(data), 'application/json')
     }
-    endpoint = f"http://127.0.0.1:8080/predictions/{model_name}"
+    endpoint = f"http://127.0.0.1:8080/invocations"
     resp = requests.post(endpoint, files=multipart_form_data)

     if resp.status_code >= 300:
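
Since both senders now target the fixed /invocations route, callers no longer thread a model name through. A usage sketch: the "parameters" key mirrors how test_handler assembles requests further down, while the "inputs" key and prompt text are assumptions used for illustration:

    # Placeholder request against a locally running server.
    resp = send_json({
        "inputs": "The future of large language models is",  # made-up prompt
        "parameters": {"max_new_tokens": 16},
    })
    print(resp.status_code, resp.content[:200])
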
@@ -459,8 +455,7 @@ def test_handler(model, model_spec):
     )
     spec = model_spec[args.model]
     if "worker" in spec:
-        check_worker_number(spec["worker"],
-                            model_name=spec.get("model_name", "test"))
+        check_worker_number(spec["worker"])
     for i, batch_size in enumerate(spec["batch_size"]):
         for seq_length in spec["seq_length"]:
             if "t5" in model:
@@ -470,7 +465,7 @@ def test_handler(model, model_spec):
             params = {"max_new_tokens": seq_length}
             req["parameters"] = params
             logging.info(f"req {req}")
-            res = send_json(req, model_name=spec.get("model_name", "test"))
+            res = send_json(req)
             if spec.get("stream_output", False):
                 logging.info(f"res: {res.content}")
                 result = res.content.decode().split("\n")[:-1]
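
For streamed output the test treats the body as newline-delimited chunks and drops the empty tail produced by the final newline. A tiny illustration with a made-up payload:

    # Made-up streamed body; the split/[:-1] logic is the same as above.
    content = b'{"outputs": ["The"]}\n{"outputs": [" movie"]}\n'
    result = content.decode().split("\n")[:-1]
    assert result == ['{"outputs": ["The"]}', '{"outputs": [" movie"]}']
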
tests/integration/llm/sagemaker-endpoint-tests.py (1 addition, 3 deletions)

@@ -129,9 +129,7 @@
     }
 }

-ENGINE_TO_METRIC_CONFIG_ENGINE = {
-    "Python" : "Accelerate"
-}
+ENGINE_TO_METRIC_CONFIG_ENGINE = {"Python": "Accelerate"}


 def get_sagemaker_session(default_bucket=DEFAULT_BUCKET,
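
The mapping is collapsed to a one-liner; it translates a serving engine name into the engine label used by the metric configs, with Python-engine models reported under Accelerate. A hedged usage sketch — falling back to the engine's own name for unmapped engines is an assumption, not shown in the diff:

    ENGINE_TO_METRIC_CONFIG_ENGINE = {"Python": "Accelerate"}

    engine = "Python"
    metric_engine = ENGINE_TO_METRIC_CONFIG_ENGINE.get(engine, engine)  # "Accelerate"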
