"""
Implements an OpenIGTLink client that receives ultrasound (pyigtl.ImageMessage) and sends prediction/segmentation (pyigtl.ImageMessage).
Transform messages (pyigtl.TransformMessage) are also received and sent to the server, but the device name is changed by replacing Image to Prediction.
This is done to ensure that the prediction is visualized in the same position as the ultrasound image.
Arguments:
model: Path to the torchscript file you intend to use for segmentation. The model must be a torchscript model that takes a single image as input and returns a single image as output.
input device name: This is the device name the client is listening to
output device name: The device name the client outputs to
host: Server's IP the client connects to.
input port: Port used for receiving data from the PLUS server over OpenIGTLink
output port: Port used for sending data to Slicer over OpenIGTLink
"""
import argparse
import cv2
import json
import logging
import numpy as np
import traceback
import sys
import pyigtl
import time
import torch
import yaml
from pathlib import Path
from scipy.ndimage import map_coordinates
from scipy.spatial import Delaunay
ROOT = Path(__file__).parent.resolve()
# Parse command line arguments
def ScanConversionInference():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="Path to torchscript model file.")
    parser.add_argument("--scanconversion_config", type=str, help="Path to scan conversion config (.yaml) file. Optional.")
    parser.add_argument("--input-device-name", type=str, default="Image_Image")
    parser.add_argument("--output-device-name", type=str, default="Prediction")
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--input-port", type=int, default=18944)
    parser.add_argument("--output-port", type=int, default=18945)
    parser.add_argument("--log_file", type=str, default=None, help="Path to log file. Optional.")

    try:
        args = parser.parse_args()
    except SystemExit as err:
        traceback.print_exc()
        sys.exit(err.code)

    if args.log_file:
        logging.basicConfig(filename=args.log_file, filemode='w', level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    run_client(args)

def run_client(args):
    """
    Runs the client in an infinite loop, waiting for messages from the server. Once a message is received,
    the message is processed and the inference is sent back to the server as a pyigtl ImageMessage.
    """
    input_client = pyigtl.OpenIGTLinkClient(host=args.host, port=args.input_port)
    output_server = pyigtl.OpenIGTLinkServer(port=args.output_port)
    model = None

    # Initialize timer and counters for profiling
    start_time = time.perf_counter()
    preprocess_counter = 0
    preprocess_total_time = 0
    inference_counter = 0
    inference_total_time = 0
    postprocess_counter = 0
    postprocess_total_time = 0
    total_counter = 0
    total_time = 0
    image_message_counter = 0
    transform_message_counter = 0

    # Load pytorch model
    logging.info("Loading model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = args.model if Path(args.model).is_absolute() else f'{str(ROOT)}/{args.model}'
    extra_files = {"config.json": ""}
    model = torch.jit.load(model_path, _extra_files=extra_files).to(device)
    config = json.loads(extra_files["config.json"])
    input_size = config["shape"][-1]
    logging.info("Model loaded")
    # Gradients are not needed; inference in the main loop below runs inside torch.inference_mode()

    # If scan conversion is enabled, compute x_cart, y_cart, vertices, and weights for conversion and interpolation
    if args.scanconversion_config:
        logging.info("Loading scan conversion config...")
        with open(args.scanconversion_config, "r") as f:
            scanconversion_config = yaml.safe_load(f)
        x_cart, y_cart = scan_conversion_inverse(scanconversion_config)
        logging.info("Scan conversion config loaded")
    else:
        scanconversion_config = None
        x_cart = None
        y_cart = None
        logging.info("Scan conversion config not found")

    if x_cart is not None and y_cart is not None:
        vertices, weights = scan_interpolation_weights(scanconversion_config)
        mask_array = curvilinear_mask(scanconversion_config)
    else:
        vertices = None
        weights = None
        mask_array = None

    while True:
        # Log average processing times about once per second, then reset the counters
        if time.perf_counter() - start_time > 1.0:
            logging.info("--------------------------------------------------")
            logging.info(f"Image messages received: {image_message_counter}")
            logging.info(f"Transform messages received: {transform_message_counter}")
            if preprocess_counter > 0:
                avg_preprocess_time = round((preprocess_total_time / preprocess_counter) * 1000, 1)
                logging.info(f"Average preprocess time: {avg_preprocess_time} ms")
            if inference_counter > 0:
                avg_inference_time = round((inference_total_time / inference_counter) * 1000, 1)
                logging.info(f"Average inference time: {avg_inference_time} ms")
            if postprocess_counter > 0:
                avg_postprocess_time = round((postprocess_total_time / postprocess_counter) * 1000, 1)
                logging.info(f"Average postprocess time: {avg_postprocess_time} ms")
            if total_counter > 0:
                avg_total_time = round((total_time / total_counter) * 1000, 1)
                logging.info(f"Average total time: {avg_total_time} ms")
            start_time = time.perf_counter()
            preprocess_counter = 0
            preprocess_total_time = 0
            inference_counter = 0
            inference_total_time = 0
            postprocess_counter = 0
            postprocess_total_time = 0
            total_counter = 0
            total_time = 0
            image_message_counter = 0
            transform_message_counter = 0

        # Receive messages from server
        messages = input_client.get_latest_messages()
        for message in messages:
            if message.device_name == args.input_device_name:  # Image message
                image_message_counter += 1
                total_start_time = time.perf_counter()
                if model is None:
                    logging.error("Model not loaded. Exiting...")
                    break
                # Save the original image size so the prediction can be resized to match it
                orig_img_size = message.image.shape
                # Preprocess input
                preprocess_start_time = time.perf_counter()
                image = preprocess_input(message.image, input_size, scanconversion_config, x_cart, y_cart).to(device)
                preprocess_total_time += time.perf_counter() - preprocess_start_time
                preprocess_counter += 1
                # Run inference without gradient tracking
                inference_start_time = time.perf_counter()
                with torch.inference_mode():
                    prediction = model(image)
                if isinstance(prediction, list):
                    prediction = prediction[0]
                inference_total_time += time.perf_counter() - inference_start_time
                inference_counter += 1
                # Postprocess prediction
                postprocess_start_time = time.perf_counter()
                prediction = torch.nn.functional.softmax(prediction, dim=1)
                prediction = postprocess_prediction(prediction, orig_img_size, scanconversion_config, vertices, weights, mask_array)
                postprocess_total_time += time.perf_counter() - postprocess_start_time
                postprocess_counter += 1
                image_message = pyigtl.ImageMessage(prediction, device_name=args.output_device_name)
                output_server.send_message(image_message, wait=True)
                total_time += time.perf_counter() - total_start_time
                total_counter += 1
            if message.message_type == "TRANSFORM" and "Image" in message.device_name:  # Image transform message
                transform_message_counter += 1
                output_tfm_name = message.device_name.replace("Image", "Prediction")
                tfm_message = pyigtl.TransformMessage(message.matrix, device_name=output_tfm_name)
                output_server.send_message(tfm_message, wait=True)

def preprocess_input(image, input_size, scanconversion_config=None, x_cart=None, y_cart=None):
    """
    Convert an incoming pyigtl image (indexed as (1, rows, cols) below) into a (1, 1, H, W) float tensor
    for the model. If a scan conversion config is given, the curvilinear image is resampled onto the scan
    lines; otherwise it is resized to the model input size and scaled to [0, 1].
    """
    if scanconversion_config is not None:
        # Scan convert image from curvilinear to linear
        num_samples = scanconversion_config["num_samples_along_lines"]
        num_lines = scanconversion_config["num_lines"]
        converted_image = np.zeros((1, num_lines, num_samples))
        converted_image[0, :, :] = map_coordinates(image[0, :, :], [x_cart, y_cart], order=1, mode='constant', cval=0.0)
        # Squeeze converted image to remove channel dimension
        converted_image = converted_image.squeeze()
    else:
        converted_image = cv2.resize(image[0, :, :], (input_size, input_size)) / 255  # default is bilinear
    converted_image = torch.from_numpy(converted_image).unsqueeze(0).unsqueeze(0).float()
    return converted_image

def postprocess_prediction(prediction, original_size, scanconversion_config=None, vertices=None, weights=None, mask_array=None):
    """
    Convert the model output back to a uint8 image with shape (1, rows, cols), either by scan converting
    the second (foreground) class channel back to the curvilinear geometry or by resizing it to the
    original image size.
    """
    if scanconversion_config is not None:
        # Scan convert prediction from linear to curvilinear
        prediction = prediction.squeeze().detach().cpu().numpy() * 255
        prediction = scan_convert(prediction[1], scanconversion_config, vertices, weights)
        if mask_array is not None:
            prediction = prediction * mask_array
        # Make sure prediction data type is uint8
        prediction = prediction.astype(np.uint8)[np.newaxis, ...]
    else:
        prediction = prediction.squeeze().detach().cpu().numpy() * 255
        prediction = cv2.resize(prediction[1], (original_size[2], original_size[1]))
        prediction = prediction.astype(np.uint8)[np.newaxis, ...]
    return prediction
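
# The scan conversion helpers below read their parameters from a YAML config. A minimal sketch of the
# expected structure, using only keys referenced in this file (all values here are illustrative and not
# taken from any specific probe):
#
#   curvilinear_image_size: 512
#   center_coordinate_pixel: [0, 256]
#   radius_start_pixels: 100
#   radius_end_pixels: 420
#   angle_min_degrees: -37.0
#   angle_max_degrees: 37.0
#   num_lines: 128
#   num_samples_along_lines: 128
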
def scan_conversion_inverse(scanconversion_config):
    """
    Compute cartesian coordinates for inverse scan conversion.
    Mapping from a curvilinear image to a rectangular image with scan lines as columns.
    The returned cartesian coordinates can be used to map the curvilinear image to a rectangular image
    using scipy.ndimage.map_coordinates.

    Args:
        scanconversion_config (dict): Dictionary with scan conversion parameters.

    Returns:
        x_cart (np.ndarray): x coordinates of the cartesian grid.
        y_cart (np.ndarray): y coordinates of the cartesian grid.

    Example:
        >>> x_cart, y_cart = scan_conversion_inverse(scanconversion_config)
        >>> scan_converted_image = map_coordinates(ultrasound_data[0, :, :, 0], [x_cart, y_cart], order=3, mode="nearest")
        >>> scan_converted_segmentation = map_coordinates(segmentation_data[0, :, :, 0], [x_cart, y_cart], order=0, mode="nearest")
    """
    # Create sampling points in polar coordinates
    initial_angle = np.deg2rad(scanconversion_config["angle_min_degrees"])
    final_angle = np.deg2rad(scanconversion_config["angle_max_degrees"])
    radius_start_px = scanconversion_config["radius_start_pixels"]
    radius_end_px = scanconversion_config["radius_end_pixels"]

    theta, r = np.meshgrid(np.linspace(initial_angle, final_angle, scanconversion_config["num_samples_along_lines"]),
                           np.linspace(radius_start_px, radius_end_px, scanconversion_config["num_lines"]))

    # Convert the polar coordinates to cartesian coordinates
    x_cart = r * np.cos(theta) + scanconversion_config["center_coordinate_pixel"][0]
    y_cart = r * np.sin(theta) + scanconversion_config["center_coordinate_pixel"][1]

    return x_cart, y_cart

def scan_interpolation_weights(scanconversion_config):
    """
    Compute Delaunay triangle vertex indices and barycentric weights for mapping a linear (scan line)
    image back to the curvilinear grid. These only depend on the geometry, so they can be computed once
    and reused for every frame.
    """
    image_size = scanconversion_config["curvilinear_image_size"]

    x_cart, y_cart = scan_conversion_inverse(scanconversion_config)
    triangulation = Delaunay(np.vstack((x_cart.flatten(), y_cart.flatten())).T)

    grid_x, grid_y = np.mgrid[0:image_size, 0:image_size]
    simplices = triangulation.find_simplex(np.vstack((grid_x.flatten(), grid_y.flatten())).T)
    vertices = triangulation.simplices[simplices]

    # Barycentric coordinates of every output pixel within its enclosing triangle
    X = triangulation.transform[simplices, :2]
    Y = np.vstack((grid_x.flatten(), grid_y.flatten())).T - triangulation.transform[simplices, 2]
    b = np.einsum('ijk,ik->ij', X, Y)
    weights = np.c_[b, 1 - b.sum(axis=1)]

    return vertices, weights

def scan_convert(linear_data, scanconversion_config, vertices, weights):
    """
    Scan convert a linear image to a curvilinear image.

    Args:
        linear_data (np.ndarray): Linear image to be scan converted.
        scanconversion_config (dict): Dictionary with scan conversion parameters.
        vertices (np.ndarray): Triangle vertex indices from scan_interpolation_weights.
        weights (np.ndarray): Barycentric weights from scan_interpolation_weights.

    Returns:
        scan_converted_image (np.ndarray): Scan converted image.
    """
    z = linear_data.flatten()
    zi = np.einsum('ij,ij->i', np.take(z, vertices), weights)

    image_size = scanconversion_config["curvilinear_image_size"]
    return zi.reshape(image_size, image_size)
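
# Minimal usage sketch for the scan conversion helpers (assumes a scanconversion_config dict as in the
# YAML sketch above): the triangulation vertices/weights and the mask only depend on geometry, so they
# are computed once and reused for every frame.
#
#   vertices, weights = scan_interpolation_weights(scanconversion_config)
#   mask = curvilinear_mask(scanconversion_config)
#   curvilinear = scan_convert(linear_prediction, scanconversion_config, vertices, weights) * mask
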
def curvilinear_mask(scanconversion_config):
    """
    Generate a binary mask for the curvilinear image with ones inside the scan lines area and zeros outside.

    Args:
        scanconversion_config (dict): Dictionary with scan conversion parameters.

    Returns:
        mask_array (np.ndarray): Binary mask for the curvilinear image with ones inside the scan lines area and zeros outside.
    """
    angle1 = 90.0 + (scanconversion_config["angle_min_degrees"])
    angle2 = 90.0 + (scanconversion_config["angle_max_degrees"])
    center_rows_px = scanconversion_config["center_coordinate_pixel"][0]
    center_cols_px = scanconversion_config["center_coordinate_pixel"][1]
    radius1 = scanconversion_config["radius_start_pixels"]
    radius2 = scanconversion_config["radius_end_pixels"]
    image_size = scanconversion_config["curvilinear_image_size"]

    mask_array = np.zeros((image_size, image_size), dtype=np.int8)
    mask_array = cv2.ellipse(mask_array, (center_cols_px, center_rows_px), (radius2, radius2), 0.0, angle1, angle2, 1, -1)
    mask_array = cv2.circle(mask_array, (center_cols_px, center_rows_px), radius1, 0, -1)
    # Convert mask_array to uint8
    mask_array = mask_array.astype(np.uint8)
    # Repaint the borders of the mask to zero to allow erosion from all sides
    mask_array[0, :] = 0
    mask_array[:, 0] = 0
    mask_array[-1, :] = 0
    mask_array[:, -1] = 0
    # Erode mask by 10 percent of the image size to remove artifacts on the edges
    erosion_size = int(0.1 * image_size)
    mask_array = cv2.erode(mask_array, np.ones((erosion_size, erosion_size), np.uint8), iterations=1)
    return mask_array
if __name__ == "__main__":
ScanConversionInference()