Skip to content
Browse files

Added option for squared norm

  • Loading branch information...
1 parent be844ec commit f1d684511f47feb33c10ec7da633af95bb9e0391 @abeschneider committed May 31, 2012
Showing with 36 additions and 12 deletions.
  1. +25 −11 Norm.lua
  2. +11 −1 test.lua
View
36 Norm.lua
@@ -1,27 +1,41 @@
require 'nn'

--- Norm: an nn.Module producing the Euclidean norm of its input as a
-- 1-element tensor. With `squared_norm` set, it produces ||x||^2 instead.
local Norm, parent = torch.class('Norm', 'nn.Module')

--- Constructor.
-- @param squared_norm truthy to output the squared L2 norm; defaults to false
--        (plain L2 norm) when omitted.
function Norm:__init(squared_norm)
   parent.__init(self)
   -- The module maps any tensor to a single scalar, stored in a
   -- 1-element output tensor; gradInput is resized lazily in backward.
   self.output = torch.Tensor(1)
   self.gradInput = torch.Tensor()
   self.squared_norm = squared_norm or false
end
--- Forward pass: compute the L2 norm (or squared L2 norm) of `input`.
-- @param input a torch.Tensor of any shape
-- @return self.output, a 1-element tensor holding the scalar result
function Norm:updateOutput(input)
   if self.squared_norm then
      -- \sum_k x_k^{2}; input:dot(input) sums the squares without the
      -- temporary clone the original allocated via clone():pow(2).
      self.output[1] = input:dot(input)
   else
      -- \sqrt{\sum_k x_k^{2}}
      self.output[1] = input:norm(2)
   end
   -- Single exit point (the original duplicated this return inside the
   -- squared branch, leaving the duplicate dead code).
   return self.output
end
--- Backward pass: gradient of the (squared) norm with respect to the input.
-- @param input the tensor given to updateOutput
-- @param gradOutput 1-element tensor: upstream gradient w.r.t. the scalar output
-- @return self.gradInput, same shape as input
function Norm:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(input):copy(input)
   if self.squared_norm then
      -- d/dx_i \sum_k x_k^{2} = 2 x_i, then the chain rule scales by the
      -- upstream gradient. BUG FIX: the original squared branch omitted the
      -- gradOutput[1] factor, producing wrong gradients whenever the
      -- upstream gradient was not 1.
      self.gradInput:mul(2 * gradOutput[1])
   else
      -- d/dx_i \sqrt{\sum_k x_k^{2}} = x_i / ||x||
      -- NOTE(review): divides by zero for an all-zero input, as the
      -- original did — callers must avoid the zero vector here.
      self.gradInput:div(input:norm(2)):mul(gradOutput[1])
   end
   return self.gradInput
end
View
12 test.lua
@@ -6,13 +6,23 @@ tester = torch.Tester()
mytest = {}

--- Verify the analytic Jacobian of the plain (non-squared) norm against a
-- finite-difference estimate.
function mytest.TestNorm()
   local norm_module = Norm()
   local state = torch.Tensor(10, 1):zero()
   local jacobian_err = nn.Jacobian.testJacobian(norm_module, state)
   print(jacobian_err)
   tester:assertlt(jacobian_err, precision, 'error on state ')
end
--- Verify the analytic Jacobian of the squared-norm variant against a
-- finite-difference estimate.
function mytest.TestNormSquared()
   local sq_module = Norm(true)
   local state = torch.Tensor(10, 1):zero()
   local jacobian_err = nn.Jacobian.testJacobian(sq_module, state)
   print(jacobian_err)
   tester:assertlt(jacobian_err, precision, 'error on state ')
end
+
+
-- Register the suite with the shared tester and run all tests.
tester:add(mytest)
tester:run()

1 comment on commit f1d6845

@Atcold
Atcold commented on f1d6845 May 5, 2015

If I'm not mistaken, this works only for single samples and not for batches, right?

Please sign in to comment.
Something went wrong with that request. Please try again.