-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_torchvision_resnext_widget.py
107 lines (85 loc) · 3.96 KB
/
infer_torchvision_resnext_widget.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
from ikomia import utils, core, dataprocess
from ikomia.utils import pyqtutils, qtconversion
from infer_torchvision_resnext.infer_torchvision_resnext_process import ResnextParam
# PyQt GUI framework
from PyQt5.QtWidgets import *
# --------------------
# - Class which implements widget associated with the process
# - Inherits core.CWorkflowTaskWidget from Ikomia API
# --------------------
class ResnextWidget(core.CWorkflowTaskWidget):
    """Parameter widget for the ``infer_torchvision_resnext`` task.

    Lets the user pick the ResNeXt variant, the dataset the weights were
    trained on, the network input size and, for a custom dataset, the model
    weights file and the class-names file. ``on_apply`` copies the widget
    state into ``self.parameters`` and emits it.

    Args:
        param: ``ResnextParam`` instance to edit, or ``None`` to start from
            a default-constructed ``ResnextParam``.
        parent: Parent widget forwarded to the base class.
    """

    # Model names offered in the "Model name" combo box. on_apply() stores the
    # selected text verbatim in parameters.model_name and
    # _get_model_name_index() maps it back, so a single list keeps both in sync.
    _MODEL_NAMES = ("resnext50", "resnext101")

    def __init__(self, param, parent):
        core.CWorkflowTaskWidget.__init__(self, parent)

        # Edit the caller's parameters in place, or fall back to defaults.
        if param is None:
            self.parameters = ResnextParam()
        else:
            self.parameters = param

        # Create layout : QGridLayout by default
        self.grid_layout = QGridLayout()

        # Model architecture selector.
        self.combo_model = pyqtutils.append_combo(self.grid_layout, "Model name")
        for model_name in self._MODEL_NAMES:
            self.combo_model.addItem(model_name)
        self.combo_model.setCurrentIndex(self._get_model_name_index())

        # Dataset the weights were trained on.
        self.combo_dataset = pyqtutils.append_combo(self.grid_layout, "Trained on")
        self.combo_dataset.addItem("ImageNet")
        self.combo_dataset.addItem("Custom")
        self.combo_dataset.setCurrentIndex(self._get_dataset_index())
        # Connected only after the initial index is set: the slot uses the
        # path pickers, which are created below.
        self.combo_dataset.currentIndexChanged.connect(self.on_combo_dataset_changed)

        self.spin_size = pyqtutils.append_spin(self.grid_layout, label="Input size",
                                               value=self.parameters.input_size)
        self.browse_model = pyqtutils.append_browse_file(self.grid_layout, "Model path",
                                                         self.parameters.model_path)
        self.browse_classes = pyqtutils.append_browse_file(self.grid_layout, "Classes path",
                                                           self.parameters.class_file)

        # With ImageNet weights the path pickers are locked: the model path is
        # cleared, and the class file keeps the value stored in the parameters.
        if self.parameters.dataset == "ImageNet":
            self.browse_model.set_path("")
            self.browse_model.setEnabled(False)
            self.browse_classes.setEnabled(False)

        # PyQt -> Qt wrapping
        layout_ptr = qtconversion.PyQtToQt(self.grid_layout)
        # Set widget layout
        self.set_layout(layout_ptr)

    def _get_model_name_index(self):
        """Return the combo index for ``parameters.model_name``.

        Returns:
            Position of the name in ``_MODEL_NAMES``, or 0 (the first model)
            when the stored name is not one of the offered entries.
        """
        try:
            return self._MODEL_NAMES.index(self.parameters.model_name)
        except ValueError:
            return 0

    def _get_dataset_index(self):
        """Return 0 when ``parameters.dataset`` is "ImageNet", else 1 ("Custom")."""
        return 0 if self.parameters.dataset == "ImageNet" else 1

    def on_combo_dataset_changed(self, index):
        """Slot for ``combo_dataset.currentIndexChanged``.

        Selecting "ImageNet" clears the model path and points the class file
        at ``models/imagenet_classes.txt`` next to this module, then locks
        both pickers. Any other choice clears both pickers and unlocks them
        so the user can supply custom files.

        Args:
            index: New current index of the dataset combo box.
        """
        is_imagenet = self.combo_dataset.itemText(index) == "ImageNet"
        if is_imagenet:
            self.browse_model.set_path("")
            class_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                      "models", "imagenet_classes.txt")
            self.browse_classes.set_path(class_file)
        else:
            self.browse_model.clear()
            self.browse_classes.clear()
        # Pickers are editable only for a custom dataset.
        self.browse_model.setEnabled(not is_imagenet)
        self.browse_classes.setEnabled(not is_imagenet)

    def on_apply(self):
        """Apply-button slot: copy the widget state into the parameters and emit them."""
        # Flag the parameters as modified for the process.
        self.parameters.update = True
        self.parameters.model_name = self.combo_model.currentText()
        self.parameters.dataset = self.combo_dataset.currentText()
        self.parameters.input_size = self.spin_size.value()
        self.parameters.model_path = self.browse_model.path
        self.parameters.class_file = self.browse_classes.path
        # Send signal to launch the process
        self.emit_apply(self.parameters)
# --------------------
# - Factory class to build process widget object
# - Inherits dataprocess.CWidgetFactory from Ikomia API
# --------------------
class ResnextWidgetFactory(dataprocess.CWidgetFactory):
    """Factory that builds ``ResnextWidget`` instances for this process."""

    def __init__(self):
        dataprocess.CWidgetFactory.__init__(self)
        # Must equal the name declared in the process factory class so the
        # widget is associated with the right process.
        self.name = "infer_torchvision_resnext"

    def create(self, param):
        """Build and return a parentless ``ResnextWidget`` editing *param*."""
        widget = ResnextWidget(param, None)
        return widget