-
Notifications
You must be signed in to change notification settings - Fork 212
/
main.lua
126 lines (106 loc) · 4.23 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
require 'torch'
require 'nn'
require 'nngraph'
require 'paths'
require 'image'
require 'xlua'
local utils = require 'utils'
local opts = require 'opts'(arg)
-- Load optional libraries
require('cunn')
require('cudnn')
-- Load optional data-loading libraries
matio = xrequire('matio') -- matlab
npy4th = xrequire('npy4th') -- python numpy
local FaceDetector = require 'facedetection_dlib'
-- Torch runtime configuration
torch.setheaptracking(true)
torch.setdefaulttensortype('torch.FloatTensor')
torch.setnumthreads(1)

-- Build the list of images to process. requireDetectionCnt counts the
-- entries that lack a precomputed center/scale and therefore need a
-- detector pass.
local fileList, requireDetectionCnt = utils.getFileList(opts)
local predictions = {}

-- Instantiate the face detector only if at least one image needs it
local faceDetector = nil
if requireDetectionCnt > 0 then faceDetector = FaceDetector() end

-- Load the landmark model; for '3D-full' also load the depth model
local model = torch.load(opts.model)
local modelZ
if opts.type == '3D-full' then
    modelZ = torch.load(opts.modelZ)
    if opts.device ~= 'cpu' then modelZ = modelZ:cuda() end
    modelZ:evaluate()
end
-- BUGFIX: use the same device test as everywhere else in this file
-- ('~= cpu' rather than '== gpu') so both models always end up on the
-- same device for any non-cpu device value.
if opts.device ~= 'cpu' then model = model:cuda() end
model:evaluate()
-- Process each image: (optionally) detect the face, crop/normalise,
-- predict 68 2D landmarks (plus depth for '3D-full'), then plot,
-- save, and/or accumulate predictions depending on the run mode.
for i = 1, #fileList do
    local img = image.load(fileList[i].image)
    -- Convert grayscale to pseudo-rgb by replicating the single channel
    if img:size(1) == 1 then
        img = torch.repeatTensor(img, 3, 1, 1)
    end
    -- Run the detector only for entries without a precomputed scale
    local detectedFaces, detectedFace
    if fileList[i].scale == nil then
        detectedFaces = faceDetector:detect(img)
        if (#detectedFaces < 1) then goto continue end -- no face found: skip image
        -- Compute only for the first detected face for now
        fileList[i].center, fileList[i].scale = utils.get_normalisation(detectedFaces[1])
        detectedFace = detectedFaces[1]
    end
    img = utils.crop(img, fileList[i].center, fileList[i].scale, 256):view(1, 3, 256, 256)
    if opts.device ~= 'cpu' then img = img:cuda() end
    -- Sum the network output for the image and its horizontal flip
    -- (flip/shuffleLR maps the flipped prediction back to the original
    -- landmark ordering)
    local output = model:forward(img)[4]:clone()
    output:add(utils.flip(utils.shuffleLR(model:forward(utils.flip(img))[4])))
    local preds_hm, preds_img = utils.getPreds(output, fileList[i].center, fileList[i].scale)
    preds_hm = preds_hm:view(68, 2):float() * 4
    -- Depth (z-coordinate) prediction for the full 3D model
    if opts.type == '3D-full' then
        -- BUGFIX: this tensor was an accidental global named 'out'
        -- (colliding in spirit with the DiskFile local below); keep it
        -- local with a descriptive name.
        local heatmaps = torch.zeros(68, 256, 256)
        for j = 1, 68 do -- 'j': avoid shadowing the image index 'i'
            if preds_hm[j][1] > 0 then
                utils.drawGaussian(heatmaps[j], preds_hm[j], 2)
            end
        end
        heatmaps = heatmaps:view(1, 68, 256, 256)
        -- Depth net input: cropped image channels + 68 landmark heatmaps
        local inputZ = torch.cat(img:float(), heatmaps, 2)
        if opts.device ~= 'cpu' then inputZ = inputZ:cuda() end
        local depth_pred = modelZ:forward(inputZ):float():view(68, 1)
        preds_hm = torch.cat(preds_hm, depth_pred, 2)
        -- Rescale the predicted depth into original-image units
        preds_img = torch.cat(preds_img:view(68, 2), depth_pred * (1 / (256 / (200 * fileList[i].scale))), 2)
    end
    if opts.mode == 'demo' then
        if detectedFace ~= nil then
            -- Map the detection box corners into the 256x256 crop space
            -- (the space preds_hm lives in) for plotting
            detectedFace[{{3, 4}}] = utils.transform(torch.Tensor({detectedFace[3], detectedFace[4]}), fileList[i].center, fileList[i].scale, 256)
            detectedFace[{{1, 2}}] = utils.transform(torch.Tensor({detectedFace[1], detectedFace[2]}), fileList[i].center, fileList[i].scale, 256)
            -- Convert (x1,y1,x2,y2) to (x,y,width,height)
            detectedFace[3] = detectedFace[3] - detectedFace[1]
            detectedFace[4] = detectedFace[4] - detectedFace[2]
        end
        utils.plot(img, preds_hm, detectedFace)
    end
    if opts.save then
        -- Destination: <output dir>/<image basename without extension>
        local dest = opts.output .. '/' .. paths.basename(fileList[i].image, '.' .. paths.extname(fileList[i].image))
        if opts.outputFormat == 't7' then
            torch.save(dest .. '.t7', preds_img)
        elseif opts.outputFormat == 'txt' then
            -- csv without header; one landmark per row (x,y[,z])
            local file = torch.DiskFile(dest .. '.txt', 'w')
            for j = 1, 68 do -- 'j': avoid shadowing the image index 'i'
                if preds_img:size(2) == 3 then
                    file:writeString(tostring(preds_img[{j, 1}]) .. ',' .. tostring(preds_img[{j, 2}]) .. ',' .. tostring(preds_img[{j, 3}]) .. '\n')
                else
                    file:writeString(tostring(preds_img[{j, 1}]) .. ',' .. tostring(preds_img[{j, 2}]) .. '\n')
                end
            end
            file:close()
        end
        xlua.progress(i, #fileList)
    end
    if opts.mode == 'eval' then
        -- NOTE(review): the +1.75 offset is preserved from the original
        -- code — presumably a coordinate-convention correction for the
        -- evaluation metric; confirm against utils.calcDistance before
        -- changing it.
        predictions[i] = preds_img:clone() + 1.75
        xlua.progress(i, #fileList)
    end
    ::continue::
end
-- Evaluation mode: stack every per-image prediction into one tensor,
-- then compute and report the error metrics over the whole file list.
if opts.mode == 'eval' then
    predictions = torch.cat(predictions, 1)
    utils.calculateMetrics(utils.calcDistance(predictions, fileList))
end