DetReward.lua
------------------------------------------------------------------------
--[[ DetReward ]]--
-- Variance-reduced detection reinforcement criterion.
-- input : {detection prediction, baseline reward}
-- Reward is 1 when the predicted interval overlaps a ground-truth
-- interval by more than the threshold, 0 otherwise.
-- reward = scale*(Reward - baseline), where baseline is the 2nd input
-- element.
-- Note : for RNNs where R = 1 only for the last step in the sequence,
-- encapsulate it in nn.ModuleCriterion(DetReward, nn.SelectTable(-1))
------------------------------------------------------------------------
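------------------------------------------------------------------------
-- Illustrative example (not part of the original file) : with scale = 1
-- and a learned baseline of 0.5, a successful detection (Reward = 1)
-- receives reward = 1 - 0.5 = 0.5, while a miss (Reward = 0) receives
-- reward = 0 - 0.5 = -0.5, discouraging the sampled prediction.
------------------------------------------------------------------------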
require 'nn'
require 'utils'

local DetReward, parent = torch.class("nn.DetReward", "nn.Criterion")

function DetReward:__init(module, scale, criterion, threshold)
   parent.__init(self)
   self.module = module -- so the criterion can call module:reinforce(reward)
   self.scale = scale or 1 -- scale of reward
   self.criterion = criterion or nn.MSECriterion() -- baseline criterion
   self.sizeAverage = true
   self.gradInput = {torch.Tensor()}
   self.threshold = threshold or 0.3 -- minimum overlap for a successful detection
end
function DetReward:updateOutput(inputTable, target)
   assert(torch.type(inputTable) == 'table')
   local input = inputTable[1]
   local batch_size = #target
   self.det_scores = torch.Tensor(batch_size):zero()
   for b = 1, batch_size do
      -- predicted interval for sample b, reshaped to 1 x 2
      local pred = torch.Tensor(2)
      pred:copy(input[b]):resize(1, 2)
      local gts = target[b]
      -- overlap of the prediction with each ground-truth interval
      local overlap = utils.interval_overlap(gts, pred)
      local max_ov, max_idx = torch.max(overlap, 1)
      if max_ov[1][1] > self.threshold then
         self.det_scores[b] = 1 -- success : best overlap exceeds the threshold
      end
   end
   self.reward = self.det_scores:mul(self.scale)
   -- loss = -sum(reward)
   self.output = -self.reward:sum()
   if self.sizeAverage then
      self.output = self.output/input:size(1)
   end
   return self.output
end
function DetReward:updateGradInput(inputTable, target)
   local input = inputTable[1]
   local baseline = inputTable[2]
   -- reduce variance of reward using baseline
   self.detReward = self.detReward or self.reward.new()
   self.detReward:resizeAs(self.reward):copy(self.reward)
   self.detReward:add(-1, baseline)
   if self.sizeAverage then
      self.detReward:div(input:size(1))
   end
   -- broadcast reward to modules
   self.module:reinforce(self.detReward)
   -- zero gradInput (this criterion has no gradInput for the detection prediction)
   self.gradInput[1]:resizeAs(input):zero()
   -- learn the baseline reward
   self.criterion:forward(baseline, self.reward)
   self.gradInput[2] = self.criterion:backward(baseline, self.reward)
   return self.gradInput
end
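------------------------------------------------------------------------
-- Usage sketch (an assumption, not part of the original file) : wiring
-- DetReward to an agent whose locator head ends in a stochastic dpnn
-- unit (nn.ReinforceNormal) and whose second head predicts the baseline
-- reward. The names `locator`, `baselineHead` and `agent` are
-- hypothetical; input and target shapes follow updateOutput above.
------------------------------------------------------------------------
-- local locator = nn.Sequential()
--    :add(nn.Linear(10, 2))
--    :add(nn.ReinforceNormal(0.1)) -- stochastic unit that consumes the reward
-- local baselineHead = nn.Sequential()
--    :add(nn.Constant(1, 1))       -- dpnn : ignore input, output constant 1
--    :add(nn.Add(1))               -- learn a scalar baseline reward
-- local agent = nn.ConcatTable():add(locator):add(baselineHead)
--
-- local criterion = nn.DetReward(agent, 1, nil, 0.3)
-- local x = torch.randn(4, 10)
-- local output = agent:forward(x) -- {predictions, baseline}
-- local target = {}
-- for b = 1, 4 do target[b] = torch.Tensor{{0.2, 0.8}} end -- 1 gt interval each
-- local loss = criterion:forward(output, target)
-- agent:backward(x, criterion:backward(output, target))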