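"""
Real-time DES Diagnostic Assistance (Alpha).

Tkinter GUI with two modes: analyze an uploaded video or a live webcam feed.
In both modes the user drags a box around the contact lens (tracked with an
OpenCV KCF tracker) and clicks the vertices of a color-changing ROI on the lens.
A rolling window of mean RGB values from that ROI (alpha-blending corrected in
the video mode) is fed to a pretrained LSTM that estimates eyelid pressure, and
tracking failures are used as a blink proxy for blink-rate statistics.
"""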
import tkinter as tk
from tkinter import filedialog
import cv2
import os
import numpy as np
import pandas as pd
import torch
from LSTM_Pressure_RGB_model import LSTMModel
from Accessory_func import histogram_grayworld_whitebalance, large_small_diff, new_smallROI, lstm_input_prep, RollingBuffer
# loading the pressure RGB neural network model
input_size = 3
hidden_size = 128
num_layers = 3
output_size = 1
Pressure_RGB_model = LSTMModel(input_size, hidden_size, num_layers, output_size)
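# The model maps a sequence of per-frame RGB features (input_size = 3) to a single
# scalar pressure estimate (output_size = 1); sequences are assembled below with
# lstm_input_prep before each forward pass.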
#current_dir = os.getcwd()
current_file_directory = os.path.dirname(os.path.abspath(__file__))
##### Loading LSTM Model Pretrained Weights #####
#state_dict = torch.load(f'{current_file_directory}/P_RGBLSTM_alpha_e200.pt')
# NOTE: hard-coded absolute path; point this at the pretrained weight file on your machine
state_dict = torch.load('/Users/chenshu/Documents/Research/Terasaki Research/Mechenochromic (Zhu)/Model Log/LSTM+3FC/W alpha 2.0 30lookback/P_RGBLSTM_alpha2.0_e20.pt')
Pressure_RGB_model.load_state_dict(state_dict)
Pressure_RGB_model.eval()
# points = [] # storing the vertices of the ROI (global)
# LSTM_data = torch.tensor(np.zeros(shape = (9, 3)), dtype = torch.float32)
alpha = {'red': 0.19472346,
'green': 0.82113903,
'blue': 0.6681795} # for alpha blending correction
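# Alpha blending model used for the correction below: per channel, the observed ROI value
# is assumed to be observed = alpha*foreground + (1 - alpha)*background, so the lens
# foreground colour is recovered as
#   foreground = (observed - (1 - alpha)*background) / alpha
# which is how reference_foreground_RGB and the corrected rgb_tensor are computed later.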
def upload_video():
file_path = filedialog.askopenfilename(filetypes=[("Video files", "*.mp4 *.mov *.avi")])
points = []
iris_center = []
background_points = []
#LSTM_data = torch.tensor(np.zeros(shape = (99, 3)), dtype = torch.float32)
LSTM_data = torch.tensor(np.full((99,3), 255), dtype = torch.float32)
alpha_tensor = torch.tensor([alpha['red'], alpha['green'], alpha['blue']], dtype = torch.float32)
if file_path:
video_cap = cv2.VideoCapture(file_path)
# store blink (bool) information for blink tracking (e.g. blinking rate per minute, blink interval)
fps = video_cap.get(cv2.CAP_PROP_FPS)
# size the rolling buffers to hold one minute of frames (fps * 60);
# blink rate is then the fraction of buffered frames flagged as closed, i.e. np.sum(bool_blink)/len(bool_blink)
blink_time_store = RollingBuffer(int(fps*60)) # timestamps of each frame (currently unused)
blink_status_store = RollingBuffer(int(fps*60)) # open/closed (boolean) status of each frame
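# RollingBuffer comes from Accessory_func and is not defined here; a minimal sketch of the
# assumed behaviour (a fixed-capacity buffer that keeps only the most recent items) would be:
#   class RollingBuffer:
#       def __init__(self, maxlen):
#           self.maxlen, self.data = maxlen, []
#       def add(self, item):
#           self.data.append(item)
#           if len(self.data) > self.maxlen:
#               self.data.pop(0)
#       def get(self):
#           return self.data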
ret, first_frame = video_cap.read()
if ret:
large_roi = cv2.selectROI(first_frame)
x, y, w, h = large_roi
cv2.rectangle(first_frame, (int(large_roi[0]), int(large_roi[1])),
(int(large_roi[0] + large_roi[2]), int(large_roi[1] + large_roi[3])),
(0, 255, 0), 2)
# Initialize the KCF tracker
tracker = cv2.TrackerKCF_create()
tracker.init(first_frame, (x, y, w, h))
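# NOTE: cv2.TrackerKCF_create may require the opencv-contrib-python package, and in some
# OpenCV 4.5+ builds it is exposed as cv2.legacy.TrackerKCF_create instead.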
# for ROI selection
ROIselect_frame = first_frame.copy()
#ROIselect_frame = cv2.cvtColor(ROIselect_frame, cv2.COLOR_BGR2RGB)
def select_point(event, x, y, flags, param):
#global points
if event == cv2.EVENT_LBUTTONDOWN:
cv2.circle(ROIselect_frame, (x, y), 2, (0, 0, 255), -1)
cv2.imshow('Select ROI on lens', ROIselect_frame)
points.append([x, y])
elif event == cv2.EVENT_RBUTTONDOWN:
cv2.circle(ROIselect_frame, (x, y), 2, (0, 255, 0), -1)
cv2.imshow('Select iris center', ROIselect_frame)
iris_center.append([x, y])
cv2.namedWindow('Select ROI on lens')
cv2.setMouseCallback('Select ROI on lens', select_point)
while True:
cv2.imshow('Select ROI on lens', ROIselect_frame)
if cv2.waitKey(0) & 0xFF == 13: # return key
break
pts = np.array(points, np.int32)
diff = large_small_diff(pts, large_roi)
iris_center_pt = np.array(iris_center, np.int32)
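# reflect each ROI vertex through the clicked iris center (p' = 2*c - p) so the background
# polygon mirrors the lens ROI on the opposite side of the iris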
background_points = np.array([2*iris_center_pt - pt for pt in pts])
diff_back = large_small_diff(background_points, large_roi)
print(iris_center_pt)
print(pts)
print(background_points)
#print(ROIselect_frame.shape)
# reference frame value
mask = np.zeros(ROIselect_frame.shape[:2], dtype = np.uint8)
background_mask = np.zeros(ROIselect_frame.shape[:2], dtype = np.uint8)
cv2.fillPoly(mask, [pts.reshape((-1, 1, 2))], (255, 255, 255))
cv2.fillPoly(background_mask, [background_points.reshape((-1, 1, 2))], (255, 255, 255))
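# the filled-polygon masks select only the pixels inside the lens ROI and its mirrored
# background region, so the per-channel means below are taken over those regions alone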
reference_RGB = []
background_RGB = []
for i in range(3):
std_frame = histogram_grayworld_whitebalance(ROIselect_frame)
channel_values = std_frame[:, :, i][mask == 255]
background_values = std_frame[:, :, i][background_mask == 255]
reference_RGB.append(np.mean(channel_values))
background_RGB.append(np.mean(background_values))
# OpenCV frames are BGR, so the lists above are [B, G, R]; reverse to [R, G, B] so the
# ordering matches alpha_tensor and the rgb_tensor built in the tracking loop below
reference_RGB = torch.tensor(np.array(reference_RGB[::-1]), dtype = torch.float32)
background_reference_RGB = torch.tensor(np.array(background_RGB[::-1]), dtype = torch.float32)
reference_foreground_RGB = (reference_RGB - (1-alpha_tensor)*background_reference_RGB)/alpha_tensor
# LSTM data prep
LSTM_data = torch.cat((LSTM_data, reference_RGB.unsqueeze(0)), dim = 0)
# alpha blending correction
print(background_RGB)
## realtime ROI tracking and pressure prediction
while True:
ret, frame = video_cap.read()
if not ret: # quit algorithm when video reaches the end
print("End of video")
break
tracking_success, roi_coords = tracker.update(frame)
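# a tracking failure is treated as a closed eye (blink), since the lens disappears from view
# when the eyelid closes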
blink_status_store.add(not tracking_success)
# fraction of frames in the rolling one-minute window flagged as closed
blink_rate = np.sum(blink_status_store.get())/len(blink_status_store.get())
"""
1. create an empty (some easily manipulable datatype) to store the sequential time series data
2. create a helper to dump the storage with specified length of memory (e.g. dumping the series data storage after 10 times steps)
3. change input type to take into account of n timesteps of RGB value and use the model
"""
if tracking_success:
# Convert ROI coordinates to integers
roi_coords = tuple(map(int, roi_coords))
# need to insert new ROI coordinate with tracker information about the lens
# current coordinates are stored in array form (nrow, ncol)
# loop here!
for row in range(pts.shape[0]):
pts[row, :] = new_smallROI(diff[row, :], roi_coords)
# the background (substitute background also moves with the tracker)
background_points[row, :] = new_smallROI(diff_back[row, :], roi_coords)
#pts = pts.reshape((-1, 1, 2))
mask = np.zeros(frame.shape[:2], dtype = np.uint8)
background_mask = np.zeros(frame.shape[:2], dtype = np.uint8)
cv2.fillPoly(mask, [pts.reshape((-1, 1, 2))], (255, 255, 255))
cv2.fillPoly(background_mask, [background_points.reshape((-1, 1, 2))], (255, 255, 255))
# manipulate the real-time frames
RGB_mean = []
background_RGB_mean = []
for i in range(3):
std_frame = histogram_grayworld_whitebalance(frame)
channel_values = std_frame[:, :, i][mask == 255]
background_values = std_frame[:, :, i][background_mask == 255]
RGB_mean.append(np.mean(channel_values))
background_RGB_mean.append(np.mean(background_values))
blue, green, red = RGB_mean
blue_back, green_back, red_back = background_RGB_mean
rgb_tensor = torch.tensor([red, green, blue], dtype = torch.float32)
background_tensor = torch.tensor([red_back, green_back, blue_back], dtype = torch.float32)
# alpha-corrected foreground RGB, expressed relative to the reference (first/selected) frame
# the difference between the current and reference foreground RGB is used as the model predictor
rgb_tensor = (rgb_tensor - (1-alpha_tensor)*background_tensor)/(alpha_tensor) - reference_foreground_RGB # alpha corrected
LSTM_data = torch.cat((LSTM_data, rgb_tensor.unsqueeze(0)), axis = 0)
LSTM_data = LSTM_data[-100:]
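# lstm_input_prep comes from Accessory_func; it is assumed to reshape the (lookback, 3) history
# into the batched (1, lookback, 3) float tensor the LSTM forward pass expects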
LSTM_input = lstm_input_prep(LSTM_data)
##### blinking features #####
# eyelid pressure classification
with torch.no_grad():
pressure_pred = Pressure_RGB_model(LSTM_input).item()
text = (f"Red channel is {red}, back {red_back}",
f"Green channel is {green}, back {green_back}",
f"Blue channel is {blue}, back {blue_back}",
f"Predicted Pressure is {pressure_pred}",
f"The corrected RGB is {rgb_tensor}",
#f"Blink rate of past minute {blink_rate}",
f"There are {blink_rate * 60} blinks in the past minute",
"Press q to end session")
##### update tracked regions & display real-time information #####
# draw the lens ROI
cv2.rectangle(frame, (int(roi_coords[0]), int(roi_coords[1])),
(int(roi_coords[0] + roi_coords[2]), int(roi_coords[1] + roi_coords[3])),
(0, 255, 0), 2) # (0, 255, 0) is the color (green), and 2 is the thickness
# draw the color changing ROI
for vertex_id in range(pts.reshape((-1, 1, 2)).shape[0]):
vertex = tuple(pts.reshape((-1, 1, 2))[vertex_id, :, :][0])
cv2.circle(frame, vertex, 2, (0, 0, 255), -1)
# reflected background
for background_vertex_id in range(background_points.reshape((-1, 1, 2)).shape[0]):
background_vertex = tuple(background_points.reshape((-1, 1, 2))[background_vertex_id, :, :][0])
cv2.circle(frame, background_vertex, 2, (0, 0, 255), -1)
cv2.polylines(frame, [pts], isClosed = True, color = (0, 0, 255), thickness = 2)
cv2.polylines(frame, [background_points], isClosed = True, color = (255, 0, 0), thickness = 2)
y0 = 50
dy = 40
for i, line in enumerate(text):
y = y0 + i*dy
cv2.putText(frame, line, (50, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, lineType = cv2.LINE_AA)
cv2.imshow('Image', frame)
else: # when the lens was not detected
error_frame = frame.copy()
cv2.putText(error_frame, "Tracking failed!", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow('Image', error_frame)
#print(blink_status_store.get())
if cv2.waitKey(1) == ord('q'):
break
video_cap.release()
cv2.destroyAllWindows()
def webcam_capture():
video_cap = cv2.VideoCapture(0)
LSTM_data = torch.tensor(np.zeros(shape = (9, 3)), dtype = torch.float32)
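# NOTE: this webcam path keeps a 10-step lookback (9 zero rows + the reference frame, trimmed to
# the last 10 below), whereas upload_video keeps 100; the window length should match the lookback
# the loaded weights were trained with (the weight file name suggests a 30-frame lookback)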
points = []
# keep showing the video until a good starting frame is found
exit_while = False
while True:
# store blink (bool) information for blink tracking (e.g. blinking rate per minute, blink interval)
fps = video_cap.get(cv2.CAP_PROP_FPS)
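# NOTE: some webcams report 0 for CAP_PROP_FPS; if that happens the buffers below would be
# created with zero capacity, so a sensible fallback (e.g. fps = 30) may be needed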
# size the rolling buffers to hold one minute of frames (fps * 60);
# blink rate is then the fraction of buffered frames flagged as closed, i.e. np.sum(bool_blink)/len(bool_blink)
blink_time_store = RollingBuffer(int(fps*60)) # timestamps of each frame (currently unused)
blink_status_store = RollingBuffer(int(fps*60)) # open/closed (boolean) status of each frame
ret, first_frame = video_cap.read()
if not ret: continue # skip this iteration if no webcam frame could be read
selection_frame = first_frame.copy()
guide_selectstartframe = 'Press the "y" key on keyboard to select the first frame to get started'
cv2.putText(selection_frame, guide_selectstartframe, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, lineType = cv2.LINE_AA)
cv2.imshow('Select Starting Frame', selection_frame)
# pressing the "y" key (keycode 121 == ord('y')) captures the current frame as the starting frame
if cv2.waitKey(1) & 0xFF == 121:
## selecting region of the video frame for RGB value extraction
if ret:
# guide_selectROI = ("1. Use mouse to drag a square around the lens of interest for tracking.",
# "2. Press the enter key on keyboard to confirm selection.",
# "3. Press the 'c' key on keyboard to exit the program.")
# Select the ROI
large_roi = cv2.selectROI(first_frame)
x, y, w, h = large_roi
cv2.rectangle(first_frame, (int(large_roi[0]), int(large_roi[1])),
(int(large_roi[0] + large_roi[2]), int(large_roi[1] + large_roi[3])),
(0, 255, 0), 2) # Here, (0, 255, 0) is the color (green), and 2 is the thickness
# for i, line in enumerate(guide_selectROI):
# y = 50 + i*40
# cv2.putText(first_frame, line, (50, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, lineType = cv2.LINE_AA)
# Initialize the KCF tracker
tracker = cv2.TrackerKCF_create()
tracker.init(first_frame, (x, y, w, h))
#points = []
ROIselect_frame = first_frame.copy()
def select_point(event, x, y, flags, param):
#global points
if event == cv2.EVENT_LBUTTONDOWN:
cv2.circle(ROIselect_frame, (x, y), 2, (0, 0, 255), -1)
cv2.imshow('Select ROI on lens', ROIselect_frame)
points.append([x, y])
cv2.namedWindow('Select ROI on lens')
cv2.setMouseCallback('Select ROI on lens', select_point)
while True:
cv2.imshow('Select ROI on lens', ROIselect_frame)
if cv2.waitKey(0) & 0xFF == 13: # return key
exit_while = True
break
# Convert points to numpy array
pts = np.array(points, np.int32)
diff = large_small_diff(pts, large_roi)
"""
add reference (1st frame) frame RGB here
"""
mask = np.zeros(ROIselect_frame.shape[:2], dtype = np.uint8)
cv2.fillPoly(mask, [pts.reshape((-1, 1, 2))], (255, 255, 255))
reference_RGB = []
for i in range(3):
# standardized RGB value in frame
#applying RGB standardization WILL slow down the tracking speed
std_frame = histogram_grayworld_whitebalance(ROIselect_frame)
channel_values = std_frame[:, :, i][mask == 255]
#channel_values = frame[:, :, i][mask == 255]
reference_RGB.append(np.mean(channel_values))
# reverse from OpenCV's BGR ordering to RGB so the reference matches rgb_tensor below
reference_RGB = torch.tensor(reference_RGB[::-1], dtype = torch.float32)
# LSTM data prep
LSTM_data = torch.cat((LSTM_data, reference_RGB.unsqueeze(0)), dim = 0)
if exit_while == True:
break
## realtime ROI tracking and pressure prediction
while True:
_, frame = video_cap.read()
tracking_success, roi_coords = tracker.update(frame)
blink_status_store.add(not tracking_success)
# fraction of frames in the rolling one-minute window flagged as closed
blink_rate = np.sum(blink_status_store.get())/len(blink_status_store.get())
"""
1. create an empty (some easily manipulable datatype) to store the sequential time series data
2. create a helper to dump the storage with specified length of memory (e.g. dumping the series data storage after 10 times steps)
3. change input type to take into account of n timesteps of RGB value and use the model
"""
if tracking_success:
roi_coords = tuple(map(int, roi_coords))
for row in range(pts.shape[0]):
pts[row, :] = new_smallROI(diff[row, :], roi_coords)
mask = np.zeros(frame.shape[:2], dtype = np.uint8)
cv2.fillPoly(mask, [pts.reshape((-1, 1, 2))], (255, 255, 255))
# manipulate the real-time frames
RGB_mean = []
for i in range(3):
# standardized RGB value in frame
#applying RGB standardization WILL slow down the tracking speed
std_frame = histogram_grayworld_whitebalance(frame)
channel_values = std_frame[:, :, i][mask == 255]
#channel_values = frame[:, :, i][mask == 255]
RGB_mean.append(np.mean(channel_values))
blue, green, red = RGB_mean
rgb_tensor = torch.tensor([red, green, blue], dtype = torch.float32)
# RGB value relative to the reference (first/selected) frame; note the webcam path applies no alpha blending correction
ref_rgb_tensor = rgb_tensor - reference_RGB
LSTM_data = torch.cat((LSTM_data, ref_rgb_tensor.unsqueeze(0)), axis = 0)
LSTM_data = LSTM_data[-10:]
LSTM_input = lstm_input_prep(LSTM_data)
with torch.no_grad():
pressure_pred = Pressure_RGB_model(LSTM_input).item()
text = (f"Red channel is {red}",
f"Green channel is {green}",
f"Blue channel is {blue}",
f"Predicted Pressure is {pressure_pred}",
f"There are {blink_rate * 60} blinks in the past minute",
"Press q to exit session")
# draw the lens ROI
cv2.rectangle(frame, (int(roi_coords[0]), int(roi_coords[1])),
(int(roi_coords[0] + roi_coords[2]), int(roi_coords[1] + roi_coords[3])),
(0, 255, 0), 2) # (0, 255, 0) is the color (green), and 2 is the thickness
# draw the color changing ROI
for vertex_id in range(pts.reshape((-1, 1, 2)).shape[0]):
vertex = tuple(pts.reshape((-1, 1, 2))[vertex_id, :, :][0])
cv2.circle(frame, vertex, 2, (0, 0, 255), -1)
cv2.polylines(frame, [pts], isClosed = True, color = (0, 0, 255), thickness = 2)
#cv2.putText(frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
y0 = 50
dy = 40
for i, line in enumerate(text):
y = y0 + i*dy
cv2.putText(frame, line, (50, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2, lineType = cv2.LINE_AA)
cv2.imshow('Image', frame)
else:
error_frame = frame.copy()
cv2.putText(error_frame, "Tracking failed!", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.imshow('Image', error_frame)
if cv2.waitKey(1) == ord('q'):
break
video_cap.release()
cv2.destroyAllWindows()
root = tk.Tk()
root.title("Real-time DES Diagnostic Assistance (Alpha)")
# Set the window size
window_width = 800
window_height = 425
root.geometry(f"{window_width}x{window_height}")
# Add background image
bg_image = tk.PhotoImage(file=f"{current_file_directory}/Background_App.png") # change Background_App.png to your own image file if desired
bg_label = tk.Label(root, image=bg_image)
bg_label.place(relwidth=1, relheight=1)
# Define a custom font for title
custom_font = ("Broadway", 20) # Font family and size
APP_title = tk.Label(root, text = 'Real-time DES Diagnostic Assistance', font = custom_font)
#APP_title.pack(side = 'top', anchor = 'n')
APP_title.place(x = 250, y = 150)
upload_button = tk.Button(root, text="Upload Video", command=upload_video)
upload_button.place(x = 350, y = 200)
webcam_button = tk.Button(root, text="Use Webcam", command=webcam_capture)
#webcam_button.pack(pady=10)
#webcam_button.grid(row = 5, pady=10)
webcam_button.place(x = 350, y = 230)
root.mainloop()