Skip to content

Commit d775976

Browse files
authored
Modification to test settings (#80)
* changed setting for testing * changed utils.py --------- Co-authored-by: John Calderon <john.calderon@centml.ai>
1 parent 0b8b2a8 commit d775976

File tree

5 files changed

+145
-51
lines changed

5 files changed

+145
-51
lines changed

test/config.json

Lines changed: 0 additions & 7 deletions
This file was deleted.

test/config_params.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def TestConfig():
2+
config = dict()
3+
config["model_names_from_examples"] = ["resnet", "nanogpt"]
4+
5+
return config

test/test_database.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import random
33
import deepview_profile.db.database as database
44

5-
5+
LOWER_BOUND_RAND_INT = 1
6+
UPPER_BOUND_RAND_INT = 10
67
class MockDatabaseInterface(database.DatabaseInterface):
78
def __del__(self):
89
if os.path.exists("test.sqlite"):
@@ -37,30 +38,45 @@ def test_invalid_entry_wrong_types(self):
3738
)
3839

3940
def test_adding_valid_entry(self):
40-
params = ["entry_point", random.random(), random.random(), random.randint()]
41+
params = [
42+
"entry_point",
43+
random.random(),
44+
random.random(),
45+
random.randint(LOWER_BOUND_RAND_INT, UPPER_BOUND_RAND_INT),
46+
]
4147
self.energy_table_interface.add_entry(params)
4248
query_result = self.test_database.connection.execute(
43-
"SELECT * FROM ENERGY;"
49+
"SELECT * FROM ENERGY ORDER BY ts DESC;"
4450
).fetchone()
4551
# params is passed in by reference so it have the timestamp in it
4652
assert query_result == tuple(params)
4753

4854
# add 10 valid entries and get top 3
4955
def test_get_latest_n_entries_of_entry_point(self):
5056
for _ in range(10):
51-
params = ["entry_point", random.random(), random.random(), random.randint()]
57+
params = [
58+
"entry_point",
59+
random.random(),
60+
random.random(),
61+
random.randint(LOWER_BOUND_RAND_INT, UPPER_BOUND_RAND_INT),
62+
]
5263
self.energy_table_interface.add_entry(params)
5364
for _ in range(20):
5465
params = [
5566
"other_entry_point",
5667
random.random(),
5768
random.random(),
58-
random.randint(),
69+
random.randint(LOWER_BOUND_RAND_INT, UPPER_BOUND_RAND_INT),
5970
]
6071
self.energy_table_interface.add_entry(params)
6172
entries = []
6273
for _ in range(3):
63-
params = ["entry_point", random.random(), random.random(), random.randint()]
74+
params = [
75+
"entry_point",
76+
random.random(),
77+
random.random(),
78+
random.randint(LOWER_BOUND_RAND_INT, UPPER_BOUND_RAND_INT),
79+
]
6480
entries.insert(0, params)
6581
self.energy_table_interface.add_entry(params)
6682
latest_n_entries = (
@@ -69,4 +85,4 @@ def test_get_latest_n_entries_of_entry_point(self):
6985
)
7086
)
7187
entries = [tuple(entry) for entry in entries]
72-
assert entries == latest_n_entries
88+
assert entries == latest_n_entries

test/test_driver.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,81 @@
1-
21
import pytest
3-
import json
4-
from utils import SkylineSession, BackendContext
2+
import pickle
3+
from utils import DeepviewSession, BackendContext
4+
from google.protobuf.json_format import MessageToDict
5+
from config_params import TestConfig
6+
import os
7+
8+
REPS = 2
9+
NUM_EXPECTED_MESSAGES = 6
10+
11+
12+
def get_config_name():
13+
import pkg_resources
514

6-
with open("config.json", "r") as fp:
7-
config = json.load(fp)
15+
package_versions = {p.key: p.version for p in pkg_resources.working_set}
16+
return package_versions
17+
18+
19+
config = TestConfig()
820

921
tests = list()
10-
for entry_point in config["entry_points"]:
11-
tests.append((config["skyline_bin"], entry_point))
22+
for model_name in config["model_names_from_examples"]:
23+
dir_path = os.path.join(
24+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
25+
"examples",
26+
model_name,
27+
)
28+
tests.append((model_name, dir_path))
29+
1230

13-
@pytest.mark.parametrize("skyline_bin,entry_point", tests)
14-
def test_entry_point(skyline_bin, entry_point):
31+
@pytest.mark.parametrize("test_name, entry_point", tests)
32+
def test_entry_point(test_name, entry_point):
1533
print(f"Testing {entry_point}")
16-
context = BackendContext(skyline_bin, entry_point)
34+
35+
# create new folder
36+
folder = (
37+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/tests_results"
38+
)
39+
os.makedirs(folder, exist_ok=True)
40+
41+
stdout_fd = open(os.path.join(folder, f"{test_name}_interactive_output.log"), "w")
42+
stderr_fd = open(os.path.join(folder, f"{test_name}_interactive_w_debug_output.log"), "w")
43+
context = BackendContext(entry_point, stdout_fd=stdout_fd, stderr_fd=stderr_fd)
1744
context.spawn_process()
1845

19-
sess = SkylineSession()
20-
while context.state == 0:
21-
pass
22-
sess.connect("localhost", 60120)
23-
sess.send_initialize_request()
24-
sess.send_analysis_request()
25-
while len(sess.received_messages) < 4:
26-
pass
46+
analysis_messages = list()
2747

28-
sess.cleanup()
29-
context.terminate()
48+
for reps in range(REPS):
49+
sess = DeepviewSession()
50+
while context.state == 0:
51+
pass
52+
sess.connect("localhost", 60120)
53+
sess.send_initialize_request(entry_point)
54+
sess.send_analysis_request()
55+
while (
56+
context.alive()
57+
and sess.alive()
58+
and len(sess.received_messages) < NUM_EXPECTED_MESSAGES
59+
):
60+
pass
3061

31-
assert(len(sess.received_messages) == 4)
62+
sess.cleanup()
63+
analysis_messages.extend(sess.received_messages)
64+
65+
assert len(sess.received_messages) == NUM_EXPECTED_MESSAGES, (
66+
f"Run {reps}: Expected to receive {NUM_EXPECTED_MESSAGES} got "
67+
f"{len(sess.received_messages)} (did the process terminate prematurely?)"
68+
)
69+
70+
context.terminate()
71+
# create folder to store files
72+
# flush contents to files
73+
with open(os.path.join(folder, f"{test_name}_analysis.pkl"), "wb") as fp:
74+
pickle.dump(list(map(MessageToDict, analysis_messages)), fp)
75+
# write package versions
76+
package_dict = get_config_name()
77+
with open(os.path.join(folder, "package-list.txt"), "w") as f:
78+
for k, v in package_dict.items():
79+
f.write(f"{k}={v}\n")
80+
stdout_fd.close()
81+
stderr_fd.close()

test/utils.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import select
23
import socket
34
import struct
45
import subprocess
@@ -17,27 +18,47 @@ def stream_monitor(stream, callback=None):
1718

1819
def socket_monitor(socket, callback=None):
1920
try:
20-
while True:
21-
msg_len = struct.unpack(">I", socket.recv(4))[0]
22-
msg = socket.recv(msg_len)
23-
print(f"Received message of length {msg_len}")
24-
callback(msg)
25-
except OSError:
21+
while socket:
22+
ret = select.select([socket], [], [], 1)
23+
if len(ret[0]):
24+
msg_len = struct.unpack(">I", socket.recv(4))[0]
25+
buf = None
26+
while socket and msg_len:
27+
msg = socket.recv(msg_len)
28+
if len(msg) == 0: break
29+
buf = buf + msg if buf else msg
30+
msg_len -= len(msg)
31+
32+
callback(buf)
33+
except (ValueError, OSError):
34+
# terminate whenever either recv() or select() fails.
35+
# - if the socket is closed, then recv() will raise OSError as it attempts to read
36+
# from a closed file descriptor.
37+
# - if the socket is closed, then select() will receive an invalid (i.e. negative) file descriptor
38+
# and will raise a ValueError
2639
print(f"Closing listener for socket {socket}")
2740

2841

2942
class BackendContext:
30-
def __init__(self, skyline_bin, entry_point):
43+
def __init__(self, entry_point, stdout_fd=None, stderr_fd=None):
3144
self.process = None
32-
self.skyline_bin = skyline_bin
3345
self.entry_point = entry_point
3446
self.state = 0
47+
self.stdout_fd = stdout_fd
48+
self.stderr_fd = stderr_fd
3549

3650
def on_message_stdout(self, message):
37-
message = message.decode("ascii").rstrip()
51+
# message = message.decode("ascii").rstrip()
52+
message = message.decode("ascii")
53+
if self.stdout_fd:
54+
self.stdout_fd.write(message)
3855

3956
def on_message_stderr(self, message):
40-
message = message.decode("ascii").rstrip()
57+
message = message.decode("ascii")
58+
if self.stderr_fd:
59+
self.stderr_fd.write(message)
60+
61+
message = message.rstrip()
4162
print("stderr", message)
4263
if "DeepView interactive profiling session started!" in message:
4364
self.state = 1
@@ -46,8 +67,7 @@ def spawn_process(self):
4667
# DeepView expects the entry_point filename to be relative
4768
working_dir = os.path.dirname(self.entry_point)
4869
entry_filename = os.path.basename(self.entry_point)
49-
launch_command = [self.skyline_bin, "interactive", entry_filename]
50-
70+
launch_command = ["python", "-m", "deepview_profile", "interactive", "--debug"]
5171
# Launch backend + listener threads for stdout and stderr
5272
self.process = subprocess.Popen(
5373
launch_command,
@@ -64,6 +84,9 @@ def spawn_process(self):
6484
)
6585
self.stderr_thread.start()
6686

87+
def alive(self):
88+
return self.process and self.process.poll() is None
89+
6790
def join(self):
6891
self.process.wait()
6992

@@ -73,7 +96,7 @@ def terminate(self):
7396
self.stderr_thread.join()
7497

7598

76-
class SkylineSession:
99+
class DeepviewSession:
77100
def __init__(self):
78101
self.seq_num = 0
79102
self.received_messages = []
@@ -103,9 +126,11 @@ def send_message(self, message):
103126
self.socket.sendall(length_buffer)
104127
self.socket.sendall(buf)
105128

106-
def send_initialize_request(self):
129+
def send_initialize_request(self, project_root):
107130
request = innpv_pb2.InitializeRequest()
108131
request.protocol_version = 5
132+
request.project_root = project_root
133+
request.entry_point = "entry_point.py"
109134
self.send_message(request)
110135

111136
def send_analysis_request(self):
@@ -116,12 +141,17 @@ def send_analysis_request(self):
116141
def handle_message(self, message):
117142
from_server = innpv_pb2.FromServer()
118143
from_server.ParseFromString(message)
119-
# print("From Server:")
120-
# print(from_server)
144+
print("From Server:")
145+
print(from_server)
121146

122147
self.received_messages.append(from_server)
123148
print(f"new message. total: {len(self.received_messages)}")
124149

150+
def alive(self):
151+
# makes sure that the last message is not error
152+
return (not self.received_messages) or \
153+
not self.received_messages[-1].HasField("analysis_error")
154+
125155
def cleanup(self):
126156
# Closing the socket should cause the listener thread to die
127157
self.socket.close()

0 commit comments

Comments
 (0)