Skip to content

Commit

Permalink
Training of part and object baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
Abel Gonzalez committed Apr 27, 2017
1 parent b61c315 commit 2843e5d
Show file tree
Hide file tree
Showing 24 changed files with 2,497 additions and 1 deletion.
32 changes: 32 additions & 0 deletions demo_prts.m
@@ -0,0 +1,32 @@
% Demo: joint object and part detection.
%
% End-to-end driver script: downloads all required data and models,
% prepares the part/object structures, then trains and tests first a
% baseline network and finally our joint part/object model.
%

% Add all project folders to the MATLAB path
setup();

% Download datasets (PASCAL VOC 2010 images + PASCAL-Part annotations)
downloadVOC2010();

downloadPASCALParts();

% Download the ImageNet-pretrained base network (AlexNet, Caffe weights)
downloadNetwork('modelName','imagenet-caffe-alex');

% Download precomputed Selective Search proposals
downloadSelectiveSearch();

% Create structures with part and object info
setupParts();

% Train and test the baseline part-detection network

% TODO: add an 'if' here to decide whether to train the baseline net or download it
calvinNNPartDetection();

% Train and test our model, using the baseline network above as its starting point


@@ -1,4 +1,4 @@
function [scores, index] = BoxBestOverlap(targetBoxes, testBoxes)
function [scores, index] = BoxBestOverlapFastRcnn(targetBoxes, testBoxes)
% [scores, index] = BoxBestOverlap(targetBoxes, testBoxes)
%
% Get overlap scores (Pascal-wise) for testBoxes bounding boxes
Expand Down
61 changes: 61 additions & 0 deletions matconvnet-calvin/examples/parts/DetectionPartsToPascalVOCFiles.m
@@ -0,0 +1,61 @@
function [recall, prec, ap, apUpperBound] = DetectionPartsToPascalVOCFiles(set, idxPart, idxClass, boxes, boxIms, boxClfs, compName, doEval, overlapNms)
% Filters near-duplicate boxes (NMS), writes detections in the official
% PASCAL VOC results format and optionally evaluates them.
%
% Inputs:
%   set        - dataset split name (e.g. 'val'); stored in DATAopts.testset
%   idxPart    - index of the part class in DATAopts.prt_classes
%   idxClass   - index of the object class in DATAopts.classes
%   boxes      - [N x 4] detection boxes
%   boxIms     - {N x 1} image identifiers, one per box
%   boxClfs    - [N x 1] detection confidence scores
%   compName   - competition/method name used in the results file name
%   doEval     - (optional, default 0) run VOC evaluation on the results
%   overlapNms - (optional) NMS overlap threshold; values > 0 enable filtering
%
% Outputs:
%   recall, prec, ap - VOC evaluation results (all 0 when doEval is false)
%   apUpperBound     - maximum attainable AP given the achieved recall

global DATAopts;

DATAopts.testset = set;

if ~exist('doEval', 'var')
    doEval = 0;
end

partName = DATAopts.prt_classes{idxPart};
objName = DATAopts.classes{idxClass};

% Sort scores/boxes/images by descending confidence
[boxClfs, sI] = sort(boxClfs, 'descend');
boxIms = boxIms(sI);
boxes = boxes(sI,:);

% Optionally filter near-duplicate boxes per image via NMS
if exist('overlapNms', 'var') && overlapNms > 0
    [uIms, ~, uN] = unique(boxIms);
    keepIds = true(size(boxes,1), 1);
    fprintf('Filtering %d: ', length(uIms));
    for i=1:length(uIms)
        if mod(i,500) == 0
            fprintf('%d ', i);
        end
        currIds = find(uN == i);
        [~, goodBoxesI] = BoxNMS(boxes(currIds,:), overlapNms);
        keepIds(currIds) = goodBoxesI;
    end
    boxClfs = boxClfs(keepIds);
    boxIms = boxIms(keepIds);
    boxes = boxes(keepIds,:);
    fprintf('\n');
end

% Save detections in the official VOC format:
% '<imageId> <confidence> <x1> <y1> <x2> <y2>'
savePath = fullfile(DATAopts.resdir, 'Main', ['%s_det_', set, '_%s.txt']);
resultsName = sprintf(savePath, compName, [objName '-' partName]);
fid = fopen(resultsName, 'w');
% Fail loudly instead of silently printing to an invalid file handle
if fid == -1
    error('DetectionPartsToPascalVOCFiles:fopen', ...
          'Could not open results file for writing: %s', resultsName);
end
for j=1:length(boxIms)
    fprintf(fid, '%s %f %f %f %f %f\n', boxIms{j}, boxClfs(j), boxes(j,:));
end
fclose(fid);
fprintf('\n');

if doEval
    [recall, prec, ap] = VOCevaldetParts_modified(DATAopts, partName, objName, resultsName, false);
    apUpperBound = max(recall);
else
    recall = 0;
    prec = 0;
    ap = 0;
    apUpperBound = 0;
end
@@ -0,0 +1,141 @@
function [rec,prec,ap] = VOCevaldetParts_modified(DATAopts, cls, obj, loadName, draw, flipBoxes)
% PASCAL VOC style detection evaluation, adapted to part classes.
%
% [rec,prec,ap] = VOCevaldetParts_modified(DATAopts, cls, obj, loadName, draw, flipBoxes)
%
% Inputs:
%   DATAopts  - dataset options struct; uses testset, imdbTest, minoverlap
%   cls       - part class name to evaluate
%   obj       - object class the part belongs to
%   loadName  - detection results file, one line per detection:
%               '<imageId> <confidence> <x1> <y1> <x2> <y2>'
%   draw      - if true, plot the precision/recall curve
%   flipBoxes - (optional) if true, interpret stored boxes as [y1 x1 y2 x2]
%
% Outputs:
%   rec  - recall at each detection rank
%   prec - precision at each detection rank
%   ap   - average precision (VOCap)

% load test set
tic;

gtids = GetImagesPlusLabels(DATAopts.testset);

for i=1:length(gtids)
    % display progress
    if toc>1
        fprintf('%s: pr: load: %d/%d\n',cls,i,length(gtids));
        drawnow;
        tic;
    end

    % Create annotation struct as if it was read from an annotation file
    recs(i).objects.class = [];
    recs(i).objects.bbox = [];
    recs(i).objects.difficult = [];

    % Add parts belonging to the requested object class
    objIm = DATAopts.imdbTest.parts{i};
    idxObj = 1;
    % BUGFIX: original used 1:size(objIm); the colon operator takes only
    % the first element of the size vector, which is wrong for row arrays.
    for kk = 1:numel(objIm)
        if strcmp(DATAopts.imdbTest.objects{i}.class(kk), obj) && ~isempty(objIm{kk})
            for ll = 1:size(objIm{kk}.class_id,1)
                recs(i).objects(idxObj).class = DATAopts.imdbTest.prt_classes{DATAopts.imdbTest.obj_class2id(obj)}{objIm{kk}.class_id(ll)};
                recs(i).objects(idxObj).bbox = objIm{kk}.bbox(ll,:);
                recs(i).objects(idxObj).difficult = objIm{kk}.difficult(ll);
                idxObj = idxObj + 1;
            end
        end
    end
end


fprintf('%s: pr: evaluating detections\n',cls);

% hash image ids for fast lookup
hash=VOChash_init_modified(gtids);

% extract ground truth objects of class cls per image

npos=0;
gt(length(gtids))=struct('BB',[],'diff',[],'det',[]);
for i=1:length(gtids)
    % clsinds is a logical mask over the image's parts
    clsinds=strcmp(cls,{recs(i).objects.class});
    gt(i).BB=cat(1,recs(i).objects(clsinds).bbox)';
    gt(i).diff=[recs(i).objects(clsinds).difficult];
    % BUGFIX: det must have one flag per MATCHING box; length(clsinds) is
    % the total number of parts in the image, not the number of matches.
    gt(i).det=false(sum(clsinds),1);
    % difficult ground truth does not count towards the positives total
    npos=npos+sum(~gt(i).diff);
end

% load results
[ids,confidence,b1,b2,b3,b4]=textread(loadName,'%s %f %f %f %f %f');

if exist('flipBoxes', 'var') && flipBoxes == true
    BB=[b2 b1 b4 b3]';
else
    BB=[b1 b2 b3 b4]';
end

% sort detections by decreasing confidence
[sc,si]=sort(-confidence);
ids=ids(si);
BB=BB(:,si);

% assign detections to ground truth objects
nd=length(confidence);
tp=zeros(nd,1);
fp=zeros(nd,1);
tic;
for d=1:nd
    % display progress
    if toc>1
        fprintf('%s: pr: compute: %d/%d\n',cls,d,nd);
        drawnow;
        tic;
    end

    % find ground truth image
    i=VOChash_lookup_modified(hash,ids{d});
    if isempty(i)
        error('unrecognized image "%s"',ids{d});
    elseif length(i)>1
        error('multiple image "%s"',ids{d});
    end

    % find the ground truth box with maximum overlap, if any
    bb=BB(:,d);
    ovmax=-inf;
    for j=1:size(gt(i).BB,2)
        bbgt=gt(i).BB(:,j);
        bi=[max(bb(1),bbgt(1)) ; max(bb(2),bbgt(2)) ; min(bb(3),bbgt(3)) ; min(bb(4),bbgt(4))];
        iw=bi(3)-bi(1)+1;
        ih=bi(4)-bi(2)+1;
        if iw>0 && ih>0
            % compute overlap as area of intersection / area of union
            ua=(bb(3)-bb(1)+1)*(bb(4)-bb(2)+1)+...
               (bbgt(3)-bbgt(1)+1)*(bbgt(4)-bbgt(2)+1)-...
               iw*ih;
            ov=iw*ih/ua;
            if ov>ovmax
                ovmax=ov;
                jmax=j;
            end
        end
    end
    % assign detection as true positive / don't care / false positive
    if ovmax>=DATAopts.minoverlap
        if ~gt(i).diff(jmax)
            if ~gt(i).det(jmax)
                tp(d)=1;            % true positive
                gt(i).det(jmax)=true;
            else
                fp(d)=1;            % false positive (duplicate detection)
            end
        end
        % difficult ground truth: neither tp nor fp ("don't care")
    else
        fp(d)=1;                    % false positive (insufficient overlap)
    end
end

% compute precision/recall
fp=cumsum(fp);
tp=cumsum(tp);
rec=tp/npos;
prec=tp./(fp+tp);

ap=VOCap(rec,prec);

if draw
    % plot precision/recall
    plot(rec,prec,'-');
    grid;
    xlabel 'recall'
    ylabel 'precision'
    title(sprintf('class: %s, subset: %s, AP = %.3f',cls,DATAopts.testset,ap));
end
83 changes: 83 additions & 0 deletions matconvnet-calvin/examples/parts/calvinNNPartDetection.m
@@ -0,0 +1,83 @@
% Trains and tests the baseline joint part/object detection network on
% PASCAL VOC 2010 + PASCAL-Part, starting from an ImageNet-pretrained
% AlexNet. Results and logs are written to outputFolder.
%
% Copyright by Holger Caesar, 2016

% Global variables (must be set before running this script)
global glDatasetFolder glFeaturesFolder;
assert(~isempty(glDatasetFolder) && ~isempty(glFeaturesFolder));

%%% Settings
% Dataset: VOC 2010, train on 'train', evaluate on 'val'
vocYear = 2010;
trainName = 'train';
testName = 'val';
vocName = sprintf('VOC%d', vocYear);
datasetDir = [fullfile(glDatasetFolder, vocName), '/'];


% Specify paths
outputFolder = fullfile(glFeaturesFolder, 'CNN-Models', 'Parts', vocName, sprintf('%s-testRelease', vocName));
netPath = fullfile(glFeaturesFolder, 'CNN-Models', 'matconvnet', 'imagenet-caffe-alex.mat');
logFilePath = fullfile(outputFolder, 'log.txt');

% Fix randomness for reproducibility
randSeed = 42;
rng(randSeed);

% Setup dataset-specific options and check validity
setupDataOptsPrts(vocYear, testName, datasetDir);
global DATAopts; % Database specific paths
assert(~isempty(DATAopts), 'Error: Dataset not initialized properly!');


% Task-specific options
nnOpts.testFn = @testPartDetection;
nnOpts.misc.overlapNms = 0.3;
% Objectives for both parts and objects (classification + bbox regression)
nnOpts.derOutputs = {'objectivePrt', 1, 'objectiveObj', 1, 'regressObjectivePrt', 1, 'regressObjectiveObj', 1};

% General training options
nnOpts.batchSize = 2;
nnOpts.numSubBatches = nnOpts.batchSize; % 1 image per sub-batch
nnOpts.weightDecay = 5e-4;
nnOpts.momentum = 0.9;
nnOpts.numEpochs = 16;
% 12 epochs at 1e-3, then 4 epochs at 1e-4
nnOpts.learningRate = [repmat(1e-3, 12, 1); repmat(1e-4, 4, 1)];
nnOpts.misc.netPath = netPath;
nnOpts.expDir = outputFolder;
nnOpts.convertToTrain = 0; % perform explicit conversion to our architecture below
nnOpts.fastRcnn = 0;
nnOpts.bboxRegress = 1;
nnOpts.gpus = []; % for automatic selection use: SelectIdleGpu();

% Create outputFolder
if ~exist(outputFolder, 'dir')
    mkdir(outputFolder);
end

% Start logging console output to file
diary(logFilePath);

%%% Setup
% Start from pretrained network
net = load(nnOpts.misc.netPath);

% Setup imdb
imdb = setupImdbPartDetection(trainName, testName, net);

% Create calvinNN CNN class
% Do not transform into fast-rcnn with bbox regression here
calvinn = CalvinNN(net, imdb, nnOpts);

% Perform here the conversion to the part/object architecture
calvinn.convertNetworkToPrtObjFastRcnn;

%%% Train
calvinn.train();

%%% Test
stats = calvinn.testPrtObj();

% TEST EVAL CODE WITH LOADED STATS
%%% Eval
evalPartAndObjectDetection(testName, stats, nnOpts);


0 comments on commit 2843e5d

Please sign in to comment.