forked from nightrome/matconvnet-calvin
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Training of part and object baseline
- Loading branch information
Abel Gonzalez
committed
Apr 27, 2017
1 parent
b61c315
commit 2843e5d
Showing
24 changed files
with
2,497 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
% Demo Joint object and part detection | ||
% | ||
% It prepares all the necessary structures and then trains and test several | ||
% networks | ||
% | ||
|
||
|
||
% Add folders to path | ||
setup(); | ||
|
||
% Download datasets | ||
downloadVOC2010(); | ||
|
||
downloadPASCALParts(); | ||
|
||
% Download base network | ||
downloadNetwork('modelName','imagenet-caffe-alex'); | ||
|
||
% Download Selective Search | ||
downloadSelectiveSearch(); | ||
|
||
% Create structures with part and object info | ||
setupParts(); | ||
|
||
% Train and test baseline network | ||
|
||
% Add 'if' here to decide whether train baseline net or download | ||
calvinNNPartDetection(); | ||
|
||
% Train and test our model, using as input baseline network | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
matconvnet-calvin/examples/parts/DetectionPartsToPascalVOCFiles.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
function [recall, prec, ap, apUpperBound] = DetectionPartsToPascalVOCFiles(set, idxPart, idxClass, boxes, boxIms, boxClfs, compName, doEval, overlapNms) | ||
|
||
% Filters overlapping boxes (near duplicates), creates official VOC | ||
% detection files. Evaluates results. | ||
|
||
global DATAopts; | ||
|
||
DATAopts.testset = set; | ||
|
||
if ~exist('doEval', 'var') | ||
doEval = 0; | ||
end | ||
|
||
|
||
partName = DATAopts.prt_classes{idxPart}; | ||
objName = DATAopts.classes{idxClass}; | ||
% Sort scores/boxes/images | ||
[boxClfs, sI] = sort(boxClfs, 'descend'); | ||
boxIms = boxIms(sI); | ||
boxes = boxes(sI,:); | ||
|
||
% Filter boxes if wanted | ||
if exist('overlapNms', 'var') && overlapNms > 0 | ||
[uIms, ~, uN] = unique(boxIms); | ||
keepIds = true(size(boxes,1), 1); | ||
fprintf('Filtering %d: ', length(uIms)); | ||
for i=1:length(uIms) | ||
if mod(i,500) == 0 | ||
fprintf('%d ', i); | ||
end | ||
currIds = find(uN == i); | ||
[~, goodBoxesI] = BoxNMS(boxes(currIds,:), overlapNms); | ||
keepIds(currIds) = goodBoxesI; | ||
end | ||
boxClfs = boxClfs(keepIds); | ||
boxIms = boxIms(keepIds); | ||
boxes = boxes(keepIds,:); | ||
fprintf('\n'); | ||
end | ||
|
||
|
||
|
||
% Save detection results using detection results | ||
savePath = fullfile(DATAopts.resdir, 'Main', ['%s_det_', set, '_%s.txt']); | ||
resultsName = sprintf(savePath, compName, [objName '-' partName]); | ||
fid = fopen(resultsName,'w'); | ||
for j=1:length(boxIms) | ||
fprintf(fid,'%s %f %f %f %f %f\n', boxIms{j}, boxClfs(j),boxes(j,:)); | ||
end | ||
fclose(fid); | ||
fprintf('\n'); | ||
|
||
if doEval | ||
[recall, prec, ap] = VOCevaldetParts_modified(DATAopts, partName, objName, resultsName, false); | ||
apUpperBound = max(recall); | ||
else | ||
recall = 0; | ||
prec = 0; | ||
ap = 0; | ||
apUpperBound = 0; | ||
end |
141 changes: 141 additions & 0 deletions
141
matconvnet-calvin/examples/parts/VOC_modified/VOCevaldetParts_modified.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
function [rec,prec,ap] = VOCevaldetParts_modified(DATAopts, cls, obj, loadName, draw, flipBoxes) | ||
|
||
% load test set | ||
tic; | ||
|
||
gtids = GetImagesPlusLabels(DATAopts.testset); | ||
|
||
for i=1:length(gtids) | ||
% display progress | ||
if toc>1 | ||
fprintf('%s: pr: load: %d/%d\n',cls,i,length(gtids)); | ||
drawnow; | ||
tic; | ||
end | ||
|
||
% % Create annotation struct as if it was being read from file | ||
recs(i).objects.class = []; | ||
recs(i).objects.bbox = []; | ||
recs(i).objects.difficult = []; | ||
% | ||
% add parts belonging to object class | ||
objIm = DATAopts.imdbTest.parts{i}; | ||
idxObj = 1; | ||
for kk = 1:size(objIm) | ||
if strcmp(DATAopts.imdbTest.objects{i}.class(kk), obj) && ~isempty(objIm{kk}) | ||
for ll = 1:size(objIm{kk}.class_id,1) | ||
recs(i).objects(idxObj).class = DATAopts.imdbTest.prt_classes{DATAopts.imdbTest.obj_class2id(obj)}{objIm{kk}.class_id(ll)}; | ||
recs(i).objects(idxObj).bbox = objIm{kk}.bbox(ll,:); | ||
recs(i).objects(idxObj).difficult = objIm{kk}.difficult(ll); | ||
idxObj = idxObj + 1; | ||
end | ||
end | ||
end | ||
end | ||
|
||
|
||
fprintf('%s: pr: evaluating detections\n',cls); | ||
|
||
% hash image ids | ||
hash=VOChash_init_modified(gtids); | ||
|
||
% extract ground truth objects | ||
|
||
npos=0; | ||
gt(length(gtids))=struct('BB',[],'diff',[],'det',[]); | ||
for i=1:length(gtids) | ||
% extract parts of class | ||
clsinds=strcmp(cls,{recs(i).objects.class}); | ||
gt(i).BB=cat(1,recs(i).objects(clsinds).bbox)'; | ||
gt(i).diff=[recs(i).objects(clsinds).difficult]; | ||
gt(i).det=false(length(clsinds),1); | ||
npos=npos+sum(~gt(i).diff); | ||
end | ||
|
||
% load results | ||
% [ids,confidence,b1,b2,b3,b4]=textread(sprintf(DATAopts.detrespath,id,cls),'%s %f %f %f %f %f'); | ||
[ids,confidence,b1,b2,b3,b4]=textread(loadName,'%s %f %f %f %f %f'); | ||
|
||
if exist('flipBoxes', 'var') && flipBoxes == true | ||
BB=[b2 b1 b4 b3]'; | ||
else | ||
BB=[b1 b2 b3 b4]'; | ||
end | ||
|
||
% sort detections by decreasing confidence | ||
[sc,si]=sort(-confidence); | ||
ids=ids(si); | ||
BB=BB(:,si); | ||
|
||
% assign detections to ground truth objects | ||
nd=length(confidence); | ||
tp=zeros(nd,1); | ||
fp=zeros(nd,1); | ||
tic; | ||
for d=1:nd | ||
% display progress | ||
if toc>1 | ||
fprintf('%s: pr: compute: %d/%d\n',cls,d,nd); | ||
drawnow; | ||
tic; | ||
end | ||
|
||
% find ground truth image | ||
i=VOChash_lookup_modified(hash,ids{d}); | ||
if isempty(i) | ||
error('unrecognized image "%s"',ids{d}); | ||
elseif length(i)>1 | ||
error('multiple image "%s"',ids{d}); | ||
end | ||
|
||
% assign detection to ground truth object if any | ||
bb=BB(:,d); | ||
ovmax=-inf; | ||
for j=1:size(gt(i).BB,2) | ||
bbgt=gt(i).BB(:,j); | ||
bi=[max(bb(1),bbgt(1)) ; max(bb(2),bbgt(2)) ; min(bb(3),bbgt(3)) ; min(bb(4),bbgt(4))]; | ||
iw=bi(3)-bi(1)+1; | ||
ih=bi(4)-bi(2)+1; | ||
if iw>0 & ih>0 | ||
% compute overlap as area of intersection / area of union | ||
ua=(bb(3)-bb(1)+1)*(bb(4)-bb(2)+1)+... | ||
(bbgt(3)-bbgt(1)+1)*(bbgt(4)-bbgt(2)+1)-... | ||
iw*ih; | ||
ov=iw*ih/ua; | ||
if ov>ovmax | ||
ovmax=ov; | ||
jmax=j; | ||
end | ||
end | ||
end | ||
% assign detection as true positive/don't care/false positive | ||
if ovmax>=DATAopts.minoverlap | ||
if ~gt(i).diff(jmax) | ||
if ~gt(i).det(jmax) | ||
tp(d)=1; % true positive | ||
gt(i).det(jmax)=true; | ||
else | ||
fp(d)=1; % false positive (multiple detection) | ||
end | ||
end | ||
else | ||
fp(d)=1; % false positive | ||
end | ||
end | ||
|
||
% compute precision/recall | ||
fp=cumsum(fp); | ||
tp=cumsum(tp); | ||
rec=tp/npos; | ||
prec=tp./(fp+tp); | ||
|
||
ap=VOCap(rec,prec); | ||
|
||
if draw | ||
% plot precision/recall | ||
plot(rec,prec,'-'); | ||
grid; | ||
xlabel 'recall' | ||
ylabel 'precision' | ||
title(sprintf('class: %s, subset: %s, AP = %.3f',cls,DATAopts.testset,ap)); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
% function calvinNNDetection() | ||
% | ||
% Copyright by Holger Caesar, 2016 | ||
|
||
% Global variables | ||
global glDatasetFolder glFeaturesFolder; | ||
assert(~isempty(glDatasetFolder) && ~isempty(glFeaturesFolder)); | ||
|
||
%%% Settings | ||
% Dataset | ||
vocYear = 2010; | ||
trainName = 'train'; | ||
testName = 'val'; | ||
vocName = sprintf('VOC%d', vocYear); | ||
datasetDir = [fullfile(glDatasetFolder, vocName), '/']; | ||
|
||
|
||
% Specify paths | ||
outputFolder = fullfile(glFeaturesFolder, 'CNN-Models', 'Parts', vocName, sprintf('%s-testRelease', vocName)); | ||
netPath = fullfile(glFeaturesFolder, 'CNN-Models', 'matconvnet', 'imagenet-caffe-alex.mat'); | ||
logFilePath = fullfile(outputFolder, 'log.txt'); | ||
|
||
% Fix randomness | ||
randSeed = 42; | ||
rng(randSeed); | ||
|
||
% Setup dataset specific options and check validity | ||
setupDataOptsPrts(vocYear, testName, datasetDir); | ||
global DATAopts; % Database specific paths | ||
assert(~isempty(DATAopts), 'Error: Dataset not initialized properly!'); | ||
|
||
|
||
% Task-specific | ||
nnOpts.testFn = @testPartDetection; | ||
nnOpts.misc.overlapNms = 0.3; | ||
% Objectives for both parts and objects | ||
nnOpts.derOutputs = {'objectivePrt', 1, 'objectiveObj', 1, 'regressObjectivePrt', 1, 'regressObjectiveObj', 1}; | ||
|
||
% General | ||
nnOpts.batchSize = 2; | ||
nnOpts.numSubBatches = nnOpts.batchSize; % 1 image per sub-batch | ||
nnOpts.weightDecay = 5e-4; | ||
nnOpts.momentum = 0.9; | ||
nnOpts.numEpochs = 16; | ||
nnOpts.learningRate = [repmat(1e-3, 12, 1); repmat(1e-4, 4, 1)]; | ||
nnOpts.misc.netPath = netPath; | ||
nnOpts.expDir = outputFolder; | ||
nnOpts.convertToTrain = 0; % perform explicit conversion to our architecure | ||
nnOpts.fastRcnn = 0; | ||
nnOpts.bboxRegress = 1; | ||
nnOpts.gpus = []; % for automatic selection use: SelectIdleGpu(); | ||
|
||
% Create outputFolder | ||
if ~exist(outputFolder, 'dir') | ||
mkdir(outputFolder); | ||
end | ||
|
||
% Start logging | ||
diary(logFilePath); | ||
|
||
%%% Setup | ||
% Start from pretrained network | ||
net = load(nnOpts.misc.netPath); | ||
|
||
% Setup imdb | ||
imdb = setupImdbPartDetection(trainName, testName, net); | ||
|
||
% Create calvinNN CNN class | ||
% Do not transform into fast-rcnn with bbox regression | ||
calvinn = CalvinNN(net, imdb, nnOpts); | ||
|
||
% Perform here the conversion to part/obj architecture | ||
calvinn.convertNetworkToPrtObjFastRcnn; | ||
|
||
%%% Train | ||
calvinn.train(); | ||
|
||
%%% Test | ||
stats = calvinn.testPrtObj(); | ||
|
||
% TEST EVAL CODE WITH LOADED STATS | ||
%%% Eval | ||
evalPartAndObjectDetection(testName, stats, nnOpts); |
Oops, something went wrong.