Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f729c3f
commit b89d64d
Showing
6 changed files
with
903 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
require 'nn' | ||
|
||
local HingeEmbeddingCriterionEx, parent = torch.class('HingeEmbeddingCriterionEx', 'nn.Criterion') | ||
|
||
function HingeEmbeddingCriterionEx:__init(margin) | ||
parent.__init(self) | ||
|
||
self.margin = margin or 1 | ||
self.sizeAverage = true | ||
end | ||
|
||
function HingeEmbeddingCriterionEx:updateOutput(input,y) | ||
y=-1 | ||
self.buffer = self.buffer or input.new() | ||
if not torch.isTensor(y) then | ||
self.ty = self.ty or input.new():resize(1) | ||
self.ty[1]=y | ||
y=self.ty | ||
end | ||
|
||
self.buffer:resizeAs(input):copy(input) | ||
--self.buffer[torch.eq(y, -1)] = 0 | ||
--self.output = self.buffer:sum() | ||
|
||
self.buffer:fill(self.margin):add(-1, input) | ||
self.buffer:cmax(0) | ||
--self.buffer[torch.eq(y, 1)] = 0 | ||
--self.output = self.output + self.buffer:sum() | ||
|
||
self.output = self.buffer | ||
|
||
if (self.sizeAverage == nil or self.sizeAverage == true) then | ||
self.output = self.output:sum()/input:nElement() | ||
end | ||
|
||
|
||
return self.output | ||
end | ||
|
||
|
||
function HingeEmbeddingCriterionEx:updateGradInput(input, y) | ||
y=-1 | ||
-- if not torch.isTensor(y) then self.ty[1]=y; y=self.ty end | ||
self.gradInput:resizeAs(input):fill(y) | ||
--self.gradInput[torch.cmul(torch.eq(y, -1), torch.gt(input, self.margin))] = 0 | ||
local tempmargin = input:clone():fill(self.margin) | ||
indtemp = input.new() | ||
indtemp = input.new():resize(input:size()):copy(torch.gt(tempmargin, input)) | ||
self.gradInput = self.gradInput:cmul(indtemp) | ||
|
||
if (self.sizeAverage == nil or self.sizeAverage == true) then | ||
self.gradInput:mul(1 / input:nElement()) | ||
end | ||
|
||
return self.gradInput | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
------------------------------------------------------------------------ | ||
--[[ SpatialGlimpseTwoDimbhwc ]]-- | ||
|
||
------------------------------------------------------------------------ | ||
local SpatialGlimpseTwoDimbhwc, parent = torch.class("SpatialGlimpseTwoDimbhwc", "nn.Module") | ||
|
||
function SpatialGlimpseTwoDimbhwc:__init(size, depth, scale) | ||
require 'nnx' | ||
--print(torch.type(size)) | ||
if torch.type(size)=='table' then | ||
self.height = size[1] | ||
self.width = size[2] | ||
else | ||
self.width = size | ||
self.height = size | ||
end | ||
self.depth = depth or 3 | ||
self.scale = scale or 2 | ||
--print(torch.type(self.width)) | ||
assert(torch.type(self.width) == 'number') | ||
assert(torch.type(self.height) == 'number') | ||
assert(torch.type(self.depth) == 'number') | ||
assert(torch.type(self.scale) == 'number') | ||
parent.__init(self) | ||
self.gradInput = {torch.Tensor(), torch.Tensor()} | ||
if self.scale == 2 then | ||
self.module = nn.SpatialAveragePooling(2,2,2,2) | ||
else | ||
self.module = nn.SpatialReSampling{oheight=self.height,owidth=self.width} | ||
end | ||
self.modules = {self.module} | ||
end | ||
|
||
function SpatialGlimpseTwoDimbhwc:updateOutput(inputTable) | ||
assert(torch.type(inputTable) == 'table') | ||
assert(#inputTable >= 2) | ||
local input, location = unpack(inputTable) | ||
input, location = self:toBatch(input, 3), self:toBatch(location, 1) | ||
assert(input:dim() == 4 and location:dim() == 2) | ||
|
||
--bchw | ||
self.output:resize(input:size(1), self.height, self.width, input:size(4)) | ||
outcoord = torch.clamp(torch.round(location*12), 1, 12) | ||
|
||
for sampleIdx=1,self.output:size(1) do | ||
local outputSample = self.output[sampleIdx] | ||
local inputSample = input[sampleIdx] | ||
y = outcoord[sampleIdx][1] | ||
x = outcoord[sampleIdx][2] | ||
--input is bhwc | ||
outputSample:copy(inputSample:narrow(1, y , self.height):narrow(2, x, self.width)) | ||
end | ||
self.output = self:fromBatch(self.output, 1) | ||
|
||
return self.output | ||
end | ||
|
||
function SpatialGlimpseTwoDimbhwc:updateGradInput(inputTable, gradOutput) | ||
local input, location = unpack(inputTable) | ||
if #self.gradInput ~= 2 then | ||
self.gradInput = {input.new(), input.new()} | ||
end | ||
local gradInput, gradLocation = unpack(self.gradInput) | ||
input, location = self:toBatch(input, 3), self:toBatch(location, 1) | ||
gradOutput = self:toBatch(gradOutput, 3) | ||
|
||
gradInput:resizeAs(input):zero() | ||
gradLocation:resizeAs(location):zero() -- no backprop through location | ||
|
||
gradOutput = gradOutput:view(input:size(1), self.height, self.width, input:size(4)) | ||
|
||
outcoord = torch.clamp(torch.round(location*12), 1, 12) | ||
|
||
for sampleIdx=1,gradOutput:size(1) do | ||
local gradOutputSample = gradOutput[sampleIdx] | ||
local gradInputSample = gradInput[sampleIdx] | ||
|
||
local inputSample = input[sampleIdx] | ||
y = outcoord[sampleIdx][1] | ||
x = outcoord[sampleIdx][2] | ||
local pad = gradInputSample:narrow(1, y, self.height):narrow(2, x, self.width) | ||
|
||
pad:copy(gradOutputSample) | ||
end | ||
|
||
self.gradInput[1] = self:fromBatch(gradInput, 1) | ||
self.gradInput[2] = self:fromBatch(gradLocation, 1) | ||
|
||
return self.gradInput | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
require 'dp' | ||
|
||
local ConfusionEx, parent = torch.class("ConfusionEx", "dp.Feedback") | ||
ConfusionEx.isConfusion = true | ||
|
||
function ConfusionEx:__init(config) | ||
require 'optim' | ||
config = config or {} | ||
assert(torch.type(config) == 'table' and not config[1], | ||
"Constructor requires key-value arguments") | ||
local args, bce, name, target_dim, output_module = xlua.unpack( | ||
{config}, | ||
'Confusion', | ||
'Adapter for optim.ConfusionMatrix', | ||
{arg='bce', type='boolean', default=false, | ||
help='set true when using Binary Cross-Entropy (BCE)Criterion'}, | ||
{arg='name', type='string', default='confusion', | ||
help='name identifying Feedback in reports'}, | ||
{arg='target_dim', default=-1, type='number', | ||
help='row index of target label to be used to measure confusion'}, | ||
{arg='output_module', type='nn.Module', | ||
help='module applied to output before measuring confusion matrix'} | ||
) | ||
config.name = name | ||
self._bce = bce | ||
self._output_module = output_module or nn.Identity() | ||
self._target_dim = target_dim | ||
parent.__init(self, config) | ||
end | ||
|
||
function ConfusionEx:setup(config) | ||
parent.setup(self, config) | ||
self._mediator:subscribe("doneEpoch", self, "doneEpoch") | ||
end | ||
|
||
function ConfusionEx:doneEpoch(report) | ||
if self._cm and self._verbose then | ||
print(self._id:toString()..", instance accuracy = "..self._cm.totalValid..", class accuracy = "..self._cm.averageValid) | ||
end | ||
end | ||
|
||
function ConfusionEx:_add(batch, output, report) | ||
if self._output_module then | ||
output = self._output_module:updateOutput(output) | ||
end | ||
|
||
if not self._cm then | ||
if self._bce then | ||
self._cm = optim.ConfusionMatrix({0,1}) | ||
else | ||
self._cm = optim.ConfusionMatrix(batch:targets():classes()) | ||
end | ||
self._cm:zero() | ||
end | ||
|
||
local act = self._bce and output:view(-1) or output:view(output:size(1), -1) | ||
local tgt = batch:targets():forward('b') | ||
if self._target_dim >0 then | ||
tgt=tgt[self._target_dim] | ||
end | ||
|
||
if self._bce then | ||
self._act = self._act or act.new() | ||
self._tgt = self._tgt or tgt.new() | ||
-- round it to get a class | ||
-- add 1 to get indices starting at 1 | ||
self._act:gt(act, 0.5):add(1) | ||
self._tgt:add(tgt,1) | ||
act = self._act | ||
tgt = self._tgt | ||
end | ||
|
||
if not (torch.isTypeOf(act,'torch.FloatTensor') or torch.isTypeOf(act, 'torch.DoubleTensor')) then | ||
self._actf = self.actf or torch.FloatTensor() | ||
self._actf:resize(act:size()):copy(act) | ||
act = self._actf | ||
end | ||
|
||
self._cm:batchAdd(act, tgt) | ||
end | ||
|
||
function ConfusionEx:_reset() | ||
if self._cm then | ||
self._cm:zero() | ||
end | ||
end | ||
|
||
function ConfusionEx:report() | ||
local cm = self._cm or {} | ||
if self._cm then | ||
cm:updateValids() | ||
end | ||
--valid means accuracy | ||
--union means divide valid classification by sum of rows and cols | ||
-- (as opposed to just cols.) minus valid classificaiton | ||
-- (which is included in each sum) | ||
return { | ||
[self:name()] = { | ||
matrix = cm.mat, | ||
per_class = { | ||
accuracy = cm.valids, | ||
union_accuracy = cm.unionvalids, | ||
avg = { | ||
accuracy = cm.averageValid, | ||
union_accuracy = cm.averageUnionValid | ||
} | ||
}, | ||
accuracy = cm.totalValid, | ||
avg_per_class_accuracy = cm.averageValid, | ||
classes = cm.classes | ||
}, | ||
n_sample = self._n_sample | ||
} | ||
end | ||
|
Oops, something went wrong.