-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_torchvision_resnet_process.py
178 lines (148 loc) · 6.74 KB
/
infer_torchvision_resnet_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from ikomia import core, dataprocess
from ikomia.dnn.torch import models
import os
import copy
import cv2
import torch
import torchvision.transforms as transforms
# --------------------
# - Class to handle the process parameters
# - Inherits core.CWorkflowTaskParam from Ikomia API
# --------------------
class ResnetParam(core.CWorkflowTaskParam):
    """Parameters of the ResNet classification task.

    Values are exchanged with the Ikomia application as a dict of strings
    through ``set_values`` / ``get_values``.
    """

    def __init__(self):
        core.CWorkflowTaskParam.__init__(self)
        # Architecture name forwarded to models.resnet (default 'resnet18').
        self.model_name = 'resnet18'
        # Any value other than 'Custom' requests pretrained torchvision weights;
        # 'Custom' loads the state_dict stored in model_weight_file.
        self.dataset = 'ImageNet'
        # Side length (pixels) of the square image fed to the network.
        self.input_size = 224
        # Optional path to custom weights; an empty string means "none".
        self.model_weight_file = ''
        # Text file of class names, read by the task before building the model.
        self.class_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                       "models", "imagenet_classes.txt")
        # When True, the task reloads class names and model on its next run.
        self.update = False

    def set_values(self, params):
        """Update parameters from the application's string dict.

        Args:
            params: mapping with the keys produced by ``get_values``; all
                values are strings.
        """
        self.model_name = params["model_name"]
        self.dataset = params["dataset"]
        self.input_size = int(params["input_size"])
        self.model_weight_file = params["model_weight_file"]
        self.class_file = params["class_file"]
        # Parameters affecting the loaded model may have changed: request a
        # reload, otherwise an already-loaded model would keep being used.
        self.update = True

    def get_values(self):
        """Return the parameters as a dict of strings for the application."""
        params = {
            "model_name": self.model_name,
            "dataset": self.dataset,
            "input_size": str(self.input_size),
            "model_weight_file": self.model_weight_file,
            "class_file": self.class_file,
        }
        return params
# --------------------
# - Class which implements the process
# - Inherits dataprocess.CClassificationTask from Ikomia API
# --------------------
class Resnet(dataprocess.CClassificationTask):
    """ResNet image classifier using torchvision weights or a custom state_dict.

    Classifies either the whole input image or each input object's
    sub-image, depending on ``is_whole_image_classification()``.
    """

    # Preprocessing shared by every prediction: ToTensor converts the HWC
    # image to a CHW float tensor (scaled to [0, 1] for uint8 input), then
    # Normalize applies the per-channel ImageNet mean/std used with
    # torchvision pretrained weights. Stateless, so built once.
    _TRANSFORM = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, name, param):
        dataprocess.CClassificationTask.__init__(self, name)
        # Network is built lazily on the first run (or when param.update is set).
        self.model = None
        # Run on the first CUDA device when available, otherwise on CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Work on a private copy of the caller's parameters, or on defaults.
        if param is None:
            self.set_param_object(ResnetParam())
        else:
            self.set_param_object(copy.deepcopy(param))
        # Used as torch.hub directory while loading, so downloaded torchvision
        # weights are cached next to this plugin.
        self.model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "weights")

    def get_progress_steps(self):
        """Return the number of progress steps emitted by ``run``."""
        return 2

    def predict(self, image, input_size):
        """Return class probabilities for one image.

        Args:
            image: image array accepted by ``cv2.resize`` and ``ToTensor``
                (presumably HxWx3 uint8 — TODO confirm channel order matches
                the weights).
            input_size: side length of the square resize applied first.

        Returns:
            1-D tensor of softmax probabilities, one entry per model output,
            located on ``self.device``.
        """
        resized = cv2.resize(image, (input_size, input_size))
        # Add the batch dimension expected by the network.
        input_tensor = self._TRANSFORM(resized).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(input_tensor)
            return torch.nn.functional.softmax(output[0], dim=0)

    def _load_model(self, param):
        """(Re)load class names and the network described by ``param``."""
        self.read_class_names(param.class_file)

        # An existing custom weight file switches the task to 'Custom' mode.
        # NOTE(review): a non-empty path that is not a file is silently ignored
        # and pretrained torchvision weights are used instead.
        if param.model_weight_file and os.path.isfile(param.model_weight_file):
            param.dataset = "Custom"
        use_torchvision = param.dataset != "Custom"

        torch_dir_ori = torch.hub.get_dir()
        torch.hub.set_dir(self.model_folder)
        try:
            model = models.resnet(model_name=param.model_name,
                                  use_pretrained=use_torchvision,
                                  classes=len(self.get_names()))
            if not use_torchvision:
                # NOTE(review): torch.load unpickles the file — only load
                # weight files from trusted sources.
                state_dict = torch.load(param.model_weight_file, map_location=self.device)
                model.load_state_dict(state_dict)
            model.to(self.device)
            # Inference mode: BatchNorm must use its running statistics (and
            # not update them); harmless if the model is already in eval mode.
            model.eval()
        finally:
            # Restore the global hub directory even if loading fails.
            torch.hub.set_dir(torch_dir_ori)

        # Only publish a fully loaded model.
        self.model = model
        param.update = False

    def run(self):
        """Classify the input image or each input object and publish results."""
        self.begin_task_run()
        param = self.get_param_object()
        self.emit_step_progress()

        if self.model is None or param.update:
            self._load_model(param)

        if self.is_whole_image_classification():
            src_image = self.get_input(0).get_image()
            predictions = self.predict(src_image, param.input_size)
            # Pair each probability with its class name, most probable first.
            sorted_data = sorted(zip(predictions.flatten().tolist(), self.get_names()), reverse=True)
            confidences = [str(conf) for conf, _ in sorted_data]
            names = [name for _, name in sorted_data]
            self.set_whole_image_results(names, confidences)
        else:
            for obj in self.get_input_objects():
                roi_img = self.get_object_sub_image(obj)
                if roi_img is None:
                    continue
                predictions = self.predict(roi_img, param.input_size)
                class_index = predictions.argmax().item()
                self.add_object(obj, class_index, predictions[class_index].item())

        self.emit_step_progress()
        self.end_task_run()
# --------------------
# - Factory class to build process object
# - Inherits dataprocess.CTaskFactory from Ikomia API
# --------------------
class ResnetFactory(dataprocess.CTaskFactory):
    """Describe the ResNet task to Ikomia and build instances of it."""

    def __init__(self):
        dataprocess.CTaskFactory.__init__(self)
        info = self.info

        # Identity and summary
        info.name = "infer_torchvision_resnet"
        info.short_description = "ResNet inference model for image classification."

        # Publication the algorithm comes from
        info.authors = "Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun"
        info.article = "Deep Residual Learning for Image Recognition"
        info.journal = "Conference on Computer Vision and Pattern Recognition (CVPR)"
        info.year = 2016

        # Licensing and links
        info.license = "BSD-3-Clause License"
        info.documentation_link = "https://arxiv.org/abs/1512.03385"
        info.repository = "https://github.com/Ikomia-hub/infer_torchvision_resnet"
        info.original_repository = "https://github.com/pytorch/vision"

        # Placement in the application's process tree (relative path) and icon
        info.path = "Plugins/Python/Classification"
        info.icon_path = "icons/pytorch-logo.png"

        # Versioning, search keywords and task classification
        info.version = "1.2.1"
        info.keywords = "residual,cnn,classification"
        info.algo_type = core.AlgoType.INFER
        info.algo_tasks = "CLASSIFICATION"

    def create(self, param=None):
        """Instantiate the ResNet task, optionally with the given parameters."""
        task_name = self.info.name
        return Resnet(task_name, param)