/
setup_and_run_model.py
executable file
·139 lines (117 loc) · 4.71 KB
/
setup_and_run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torchvision
from PIL import Image
def preload():
    """Download and cache the ImageNet-pretrained VGG16 weights.

    Run once up front so later model construction reuses the local
    torchvision cache instead of re-downloading.
    """
    _ = torchvision.models.vgg16(pretrained=True)
class PoseEstimationNetwork(torch.nn.Module):
    """Two-headed pose-regression network built on a pretrained VGG16 backbone.

    The VGG16 feature extractor is reused via transfer learning (ImageNet
    weights): the original classifier is dropped and replaced by two
    fully-connected branches sharing the flattened backbone features:

      * translation: 3-vector (x, y, z coordinates)
      * orientation: 4-vector quaternion, L2-normalized by ``LinearNormalized``

    For rotationally symmetric objects (``is_symetric=True``) the orientation
    branch is skipped and only the translation is returned.
    """

    def __init__(self, *, is_symetric):
        """
        Args:
            is_symetric (bool): True if the target object is rotationally
                symmetric; the orientation head is then unused in ``forward``.
        """
        super().__init__()
        self.is_symetric = is_symetric
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # (use `weights=VGG16_Weights.IMAGENET1K_V1`); kept for consistency
        # with the rest of this file and its weight cache.
        self.model_backbone = torchvision.models.vgg16(pretrained=True)  # uses cache
        # Drop the ImageNet classifier: the backbone then emits the raw
        # flattened 512*7*7 = 25088 feature vector.
        self.model_backbone.classifier = torch.nn.Identity()
        self.translation_block = torch.nn.Sequential(
            torch.nn.Linear(25088, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(64, 3),
        )
        self.orientation_block = torch.nn.Sequential(
            torch.nn.Linear(25088, 256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(64, 4),
            LinearNormalized(),  # a quaternion is a unit vector
        )

    def forward(self, x):
        """Predict pose from an image batch.

        Args:
            x: image tensor; assumed (batch, 3, 224, 224) — the backbone's
                expected input (see run_model_main's reshape). TODO confirm.

        Returns:
            (translation, orientation) tensors, or translation alone when
            ``is_symetric`` is True.
        """
        features = self.model_backbone(x)
        output_translation = self.translation_block(features)
        if not self.is_symetric:  # idiomatic replacement for `== False`
            output_orientation = self.orientation_block(features)
            return output_translation, output_orientation
        return output_translation
class LinearNormalized(torch.nn.Module):
    """Activation that L2-normalizes each row of its input.

    Used as the final layer of the orientation branch because a quaternion
    is a unit vector. Rows whose L2 norm is exactly zero are passed through
    unchanged (they are divided by 1 instead of 0, matching the original
    per-row special case).
    """

    def forward(self, x):
        """Normalize each row of ``x`` to unit L2 norm.

        Args:
            x: tensor of shape (batch_size, 4) (any (batch, d) works).

        Returns:
            Tensor of the same shape with each row scaled to unit L2 norm;
            all-zero rows are returned unchanged.
        """
        # Vectorized replacement for the original per-row Python loop
        # (which synced to host via .item() for every row) and its
        # transpose/divide/transpose dance: keepdim=True lets the division
        # broadcast directly over rows.
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
        # Substitute 1 for zero norms so all-zero rows don't produce NaNs.
        norm = torch.where(norm == 0.0, torch.ones_like(norm), norm)
        return x / norm
def pre_process_image(path_image, device):
    """Load one image file and prepare it for the network.

    Args:
        path_image: path of the image to load (converted to RGB).
        device: torch device the tensor is moved to.

    Returns:
        A single-element list holding the transformed image tensor with a
        leading batch dimension, already on ``device``.
    """
    raw = Image.open(path_image).convert("RGB")
    batch = get_transform()(raw).unsqueeze(0)
    return [batch.to(device)]
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Returns:
        A torchvision ``Compose`` transform that resizes to 224x224 and
        converts to a tensor.
        See https://pytorch.org/docs/stable/torchvision/transforms.html
    """
    steps = [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ]
    return torchvision.transforms.Compose(steps)
# Lazily-initialized model singleton, populated by run_model_main on first
# call so the checkpoint is loaded only once per process. (A `global`
# statement at module scope is a no-op in Python and was removed.)
model = None
def run_model_main(image_file_png, model_file_name):
    """Run pose estimation on one image, lazily loading the model once.

    Args:
        image_file_png: path of the input image.
        model_file_name: path of the checkpoint; expected to be a dict whose
            "model" key holds the network state_dict.

    Returns:
        (orientation, translation) numpy arrays — note the order is the
        reverse of the model's own (translation, orientation) output.
    """
    global model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model is None:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True on recent torch).
        checkpoint = torch.load(model_file_name, map_location=device)
        model = PoseEstimationNetwork(is_symetric=False)
        model.load_state_dict(checkpoint["model"])
        model.to(device)
        model.eval()
    image = pre_process_image(image_file_png, device)
    # Inference only: no_grad avoids building the autograd graph, which
    # also makes the previous .detach() calls unnecessary.
    with torch.no_grad():
        output_translation, output_orientation = model(
            torch.stack(image).reshape(-1, 3, 224, 224).to(device)
        )
    output_translation = output_translation.cpu().numpy()
    output_orientation = output_orientation.cpu().numpy()
    return output_orientation, output_translation