Skip to content

Commit

Permalink
Add final per class non-maximum supperession
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Sep 21, 2015
1 parent 1f94e34 commit 128b857
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 31 deletions.
5 changes: 5 additions & 0 deletions Rect.lua
Expand Up @@ -87,6 +87,11 @@ function Rect:contains(otherRect)
return self:containsPt(otherRect.minX, otherRect.minY) and self:containsPt(otherRect.maxX, otherRect.maxY)
end

function Rect:overlaps(other)
return self.minX < other.maxX and self.maxX > other.minX
and self.minY < other.maxY and self.maxY > other.minY
end

function Rect:normalize()
local l, t, r, b
if self.minX <= self.maxX then
Expand Down
65 changes: 41 additions & 24 deletions main.lua
Expand Up @@ -141,7 +141,7 @@ function extract_roi_pooling_input(input_rect, localizer, feature_layer_output)
-- the use of math.min ensures correct handling of empty rects,
-- +1 offset for top, left only is conversion from half-open 0-based interval
local s = feature_layer_output:size()
r = r:clip(Rect.new(0,0,s[3],s[2]))
r = r:clip(Rect.new(0, 0, s[3], s[2]))
local idx = { {}, { math.min(r.minY + 1, r.maxY), r.maxY }, { math.min(r.minX + 1, r.maxX), r.maxX } }
return feature_layer_output[idx], idx
end
Expand Down Expand Up @@ -378,7 +378,7 @@ function precompute_positive_list(out_fn, positive_threshold, negative_threshold
save_obj(out_fn, training_data)
end

function graph_evaluate(training_data_filename, network_filename, normalize)
function graph_evaluate(training_data_filename, network_filename, normalize, bgclass)
local training_data = load_obj(training_data_filename)
local ground_truth = training_data.ground_truth
local image_file_names = training_data.image_file_names
Expand Down Expand Up @@ -420,6 +420,7 @@ function graph_evaluate(training_data_filename, network_filename, normalize)
-- load image
local input = load_image_auto_size(fn, training_data.target_smaller_side, training_data.max_pixel_size, 'yuv')
local input_size = input:size()
local input_rect = Rect.new(0, 0, input_size[3], input_size[2])
input = normalize_image(input):cuda()

-- pass image through network
Expand Down Expand Up @@ -448,8 +449,8 @@ function graph_evaluate(training_data_filename, network_filename, normalize)

-- classification
local c = lsm:forward(cls_out)
if math.exp(c[1]) > 0.9 then
table.insert(matches, { p=c[1], a=a, r=r, l=i })
if math.exp(c[1]) > 0.95 and r:overlaps(input_rect) then
table.insert(matches, { p=c[1], a=a, r=r, l=i })
end

end
Expand All @@ -459,7 +460,6 @@ function graph_evaluate(training_data_filename, network_filename, normalize)

local winners = {}

print(#matches)
if #matches > 0 then

-- NON-MAXIMUM SUPPRESSION
Expand All @@ -470,50 +470,67 @@ function graph_evaluate(training_data_filename, network_filename, normalize)

local iou_threshold = 0.5
local pick = nms(bb, iou_threshold, 'area')

pick:apply(function (x) table.insert(winners, matches[x]) end )
local candidates = {}
pick:apply(function (x) table.insert(candidates, matches[x]) end )

-- REGION CLASSIFICATION

cnet:evaluate()

-- create cnet input batch
local cinput = torch.CudaTensor(#winners, 7 * 7 * 300)
for i,v in ipairs(winners) do
local cinput = torch.CudaTensor(#candidates, 7 * 7 * 300)
for i,v in ipairs(candidates) do
-- pass through adaptive max pooling operation
local pi, idx = extract_roi_pooling_input(v.r, localizer, outputs[5])
local po = amp:forward(pi):view(7 * 7 * 300)
cinput[i] = po:clone()
cinput[i] = amp:forward(pi):view(7 * 7 * 300)
end

-- send extracted roi-data through classification network
local coutputs = cnet:forward(cinput)

-- compute classification and regression error and run backward pass
local bbox_out = coutputs[1]
local cls_out = coutputs[2]

for i=1,#winners do
winners[i].r2 = Anchors.anchorToInput(winners[i].a, bbox_out[i])
local yclass = {}
for i,x in ipairs(candidates) do
x.r2 = Anchors.anchorToInput(x.r, bbox_out[i])

local cprob = cls_out[i]
local p,c = torch.sort(cprob, 1, true) -- get probabilities and class indicies

winners[i].class = c[1]
winners[i].confidence = p[1]
x.class = c[1]
x.confidence = p[1]

if x.class ~= bgclass and math.exp(x.confidence) > 0.2 then
if not yclass[x.class] then
yclass[x.class] = {}
end

table.insert(yclass[x.class], x)
end
end

-- run per class NMS
for i,c in pairs(yclass) do
-- fill rect tensor
bb = torch.Tensor(#c, 5)
for j,r in ipairs(c) do
bb[{j, {1,4}}] = r.r2:totensor()
bb[{j, 5}] = r.confidence
end

pick = nms(bb, 0.1, bb[{{}, 5}])
pick:apply(function (x) table.insert(winners, c[x]) end )

end

end

-- load image back to rgb-space before drawing rectangles
local img = load_image_auto_size(fn, training_data.target_smaller_side, training_data.max_pixel_size, 'rgb')

for i,m in ipairs(winners) do
local color
if m.class ~= 17 and math.exp(m.confidence) > 0.25 then
draw_rectangle(img, m.r, blue)
--draw_rectangle(img, m.r2, green)
end
--draw_rectangle(img, m.r, blue)
draw_rectangle(img, m.r2, green)
end

image.saveJPG(string.format('dummy%d.jpg', n), img)
Expand Down Expand Up @@ -586,5 +603,5 @@ function graph_training(training_data_filename, network_filename)
end

--precompute_positive_list('training_data.t7', 0.6, 0.3)
graph_training('training_data.t7')
--graph_evaluate('training_data.t7', 'full2_022000.t7', true)
--graph_training('training_data.t7')
graph_evaluate('training_data.t7', 'full2_026000.t7', true, 17)
15 changes: 8 additions & 7 deletions nms.lua
Expand Up @@ -34,14 +34,15 @@ function nms(boxes, overlap, scores)

local area = torch.cmul(x2 - x1 + 1, y2 - y1 + 1)

local v, I
if scores == 'area' then
v, I = area:sort(1)
elseif scores then
v, I = scores:sort()
else -- use max_y
v, I = y2:sort(1)
if type(scores) == 'number' then
scores = boxes[{{}, scores}]
elseif scores == 'area' then
scores = area
else
scores = y2 -- use max_y
end

local v, I = scores:sort(1)

pick:resize(area:size()):zero()
local count = 1
Expand Down

0 comments on commit 128b857

Please sign in to comment.