test_accuracy.py
import argparse
import json
import os
import subprocess
import sys
import time

import requests


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp", type=int, required=True, help="Number of GPUs to use.")
    parser.add_argument("--model_dir", type=str, required=True, help="Directory of the model.")
    return parser.parse_args()


def start_server(tp, model_dir):
    """Launch the lightllm API server as a subprocess and return the process handle."""
    cmd = [
        "python",
        "-m",
        "lightllm.server.api_server",
        "--tp",
        str(tp),
        "--model_dir",
        model_dir,
        "--data_type",
        "fp16",
        "--mode",
        "triton_gqa_flashdecoding",
        "--trust_remote_code",
        "--tokenizer_mode",
        "fast",
        "--host",
        "0.0.0.0",
        "--port",
        "8080",
    ]
    process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr)
    return process


def check_health():
    """Return True once the server's /health endpoint answers with HTTP 200."""
    health_url = "http://localhost:8080/health"
    try:
        r = requests.get(health_url, timeout=2)
        return r.status_code == 200
    except Exception:
        return False


def send_prompts(prompts, output_file):
    """Send each prompt to the /generate endpoint and append the results to output_file."""
    for prompt in prompts:
        # Wait until the server reports healthy before sending the next request.
        while not check_health():
            time.sleep(1)
        request_data = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 1024, "frequency_penalty": 1, "do_sample": False},
            "multimodal_params": {},
        }
        try:
            r = requests.post("http://localhost:8080/generate", json=request_data, timeout=10)
            response_json = json.loads(r.text)
            generated_text = (
                response_json["generated_text"][0] if "generated_text" in response_json else "No generated_text."
            )
        except Exception as e:
            generated_text = f"ERROR: {str(e)}"
        with open(output_file, "a", encoding="utf-8") as f:
            f.write(f"===== prompt: {prompt} =====\n")
            f.write(f"{generated_text}\n\n")
    print(f"===================Output saved in {output_file}===========================")


def main():
    args = parse_args()
    tp = args.tp
    model_dir = args.model_dir

    # Start from a clean results file.
    output_file = "test_results.txt"
    if os.path.exists(output_file):
        os.remove(output_file)

    # Start the server.
    process = start_server(tp, model_dir)

    # A mix of English and Chinese prompts to spot-check generation quality.
    prompts = [
        "What is machine learning?",
        "1+1等于几",
        "What role does attention play in transformer architectures?",
        "西红柿炒鸡蛋怎么做?",
        "Describe the concept of overfitting and underfitting.",
        "CPU和GPU的区别是什么?",
        "What is the role of a loss function in machine learning?",
    ]
    try:
        send_prompts(prompts, output_file)
    finally:
        # Shut down the server even if a request fails.
        process.terminate()
        process.wait()


if __name__ == "__main__":
    main()

# Usage: python test_accuracy.py --tp 2 --model_dir /xx/xx
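
# For reference, a sketch of the layout test_results.txt ends up with after a run,
# based on the f.write() format in send_prompts(); the generated text shown here is
# a placeholder and will differ depending on the model and sampling settings.
#
# ===== prompt: What is machine learning? =====
# <model output for this prompt>
#
# ===== prompt: 1+1等于几 =====
# <model output for this prompt>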