Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Tried carrying around images instead of patches, but the for loop in computeFeature is too slow in MATLAB
  • Loading branch information...
commit 79d9649b05c5fa6f0302995099d207a5f5fc388c 1 parent ad23f0c
@akanazawa authored
View
23 DecisionTree.m
@@ -121,7 +121,7 @@ function normalizeAll(DT)
end
function node = computeDepthFirstByImages(DT, node, data, Is, depth)
- if isempty(labels), node=[];, end
+ if isempty(data), node=[];, end
if depth == DT.maxDepth || numel(unique([data.label]))==0
classDist = zeros(DT.numClass, 1); % will fill this out later
node = TreeNode(classDist, DT.numNodes, depth);
@@ -135,7 +135,8 @@ function normalizeAll(DT)
DT.numNodes);
for i = 1:DT.numFeature
method = mod(i, numFactory)+1; %DT.factory{mod(i, numFactory)+1};
- [values, decider] = computeFeature(data, Is, method);
+ keyboard
+ [values, decider] = computeFeatureByImages(data, Is, method);
[score, threshold] = DT.computeBestThreshold(values, [data.label]);
if score > bestScore
bestScore = score;
@@ -150,11 +151,11 @@ function normalizeAll(DT)
DT.numNodes = DT.numNodes + 1;
node.decider = bestDecider;
% do the split
- [values, ~] = computeFeature(data, bestDecider);
+ [values, ~] = computeFeatureByImages(data, bestDecider);
toLeft = values < bestDecider.threshold;
- node.left = DT.computeDepthFirst(TreeNode(), ...
+ node.left = DT.computeDepthFirstByImages(TreeNode(), ...
data(toLeft), Ids, depth+1);
- node.right = DT.computeDepthFirst(TreeNode(), ...
+ node.right = DT.computeDepthFirstByImages(TreeNode(), ...
data(~toLeft), Ids, depth+1);
end
@@ -209,10 +210,10 @@ function fillByImages(DT, node, data, Is)
node.distribution = node.distribution + ...
hist([data.label], 1:DT.numClass)'.*DT.labelWeights;
if ~node.isLeaf
- [values, ~] = computeFeature(data, Is, node.decider);
+ [values, ~] = computeFeatureByImages(data, Is, node.decider);
toLeft = values < node.decider.threshold;
- if sum(toLeft)~=0, DT.fill(node.left, data(toLeft), Is); end
- if sum(~toLeft)~=0 DT.fill(node.right, data(~toLeft), Is); end
+ if sum(toLeft)~=0, DT.fillByImages(node.left, data(toLeft), Is); end
+ if sum(~toLeft)~=0 DT.fillByImages(node.right, data(~toLeft), Is); end
end
end
@@ -262,15 +263,15 @@ function normalize(DT, node)
dist(:, ids) = repmat(node.distribution, [1, length(ids)]);
return
end
- [values, ~] = computeFeature(data, node.decider);
+ [values, ~] = computeFeatureByImages(data, node.decider);
toLeft = values < node.decider.threshold;
if sum(toLeft) ~= 0
- [dist, bost] = DT.findLeafDist(node.left, ...
+ [dist, bost] = DT.findLeafDistByImages(node.left, ...
data(toLeft), ...
Is, dist, ids(toLeft), bost);
end
if sum(~toLeft)~= 0
- [dist, bost] = DT.findLeafDist(node.right, ...
+ [dist, bost] = DT.findLeafDistByImages(node.right, ...
data(~toLeft), ...
Is, dist, ids(~toLeft), bost);
end
View
29 computeFeature.m
@@ -0,0 +1,29 @@
+function [val, feat] = computeFeature(patches, param)
+% COMPUTEFEATURE  Pixel-combination feature evaluated on a stack of patches.
+%
+%   [val, feat] = computeFeature(patches, param)
+%
+%   Inputs
+%     patches : d x d x N x 3 array (square patches, N samples, 3 channels).
+%     param   : either a method id (1..4), in which case a new random
+%               feature is drawn, or a feature struct previously returned
+%               by this function, which is then re-applied unchanged.
+%
+%   Outputs
+%     val  : N x 1 double vector, one feature response per patch.
+%     feat : struct with fields
+%              rows, cols : 2 x 1 pixel positions inside the patch
+%              channels   : 2 x 1 channel indices
+%              method     : 1 'unary', 2 'subAbs', 3 'addTwo', 4 'sub'
+%
+%   Raises computeFeature:unknownMethod when feat.method is not 1..4.
+if ~isstruct(param)
+    % param is a method id: draw two random pixel/channel probes.
+    d = size(patches, 1); % boxSize
+    feat.rows = randi(d, 2, 1);
+    feat.cols = randi(d, 2, 1);
+    feat.channels = randi(3, 2, 1);
+    feat.method = param;
+else
+    feat = param;
+end
+
+% Convert to double BEFORE any arithmetic. MATLAB integer arithmetic
+% saturates, so with integer patches (e.g. uint8) A - B would clamp at 0
+% and A + B at the type maximum, silently corrupting 'sub', 'subAbs' and
+% 'addTwo'. For patches that are already double this is a no-op.
+A = double(patches(feat.rows(1), feat.cols(1), :, feat.channels(1)));
+B = double(patches(feat.rows(2), feat.cols(2), :, feat.channels(2)));
+
+switch feat.method
+    case 1 % 'unary'
+        val = A;
+    case 2 % 'subAbs'
+        val = abs(A - B);
+    case 3 % 'addTwo'
+        val = A + B;
+    case 4 % 'sub'
+        val = A - B;
+    otherwise
+        % Without this branch val would be left undefined and the failure
+        % would surface later as an unrelated "undefined variable" error.
+        error('computeFeature:unknownMethod', ...
+              'Unknown feature method: %s', mat2str(feat.method));
+end
+
+% Collapse the 1 x 1 x N slice into an N x 1 column.
+val = val(:);
+
View
7 config.m
@@ -1,7 +1,6 @@
%%%%%%%%%%%%%%%%%%%%
% CONFIGURATION file for STF
%%%%%%%%%%%%%%%%%%%%
-addpath('util');
% directory settings
DIR.dataset ='/Users/kanazawa/Documents/projects/datasets/MSRC21/';
DIR.images = fullfile(DIR.dataset, 'Images');
@@ -11,12 +10,12 @@
PATH.trainingNames = fullfile(DIR.dataset, 'trainval.txt');
PATH.testNames = fullfile(DIR.dataset, 'test.txt');
PATH.trainingPatches = fullfile(DIR.result, 'trainingPatches.mat');
-PATH.trainingPointsSub = fullfile(DIR.result, 'trainingPointsSub.mat');
+%PATH.trainingPointsSub = fullfile(DIR.result, 'trainingPointsSub.mat');
PATH.labelWeights = fullfile(DIR.result, 'labelWeights.mat');
PATH.forestSkeleton = fullfile(DIR.result, 'forestSkeleton.mat');
PATH.forestFilled = fullfile(DIR.result, 'forestFilled.mat');
-PATH.forestSkeletonByImages = fullfile(DIR.result, 'forestSkeletonByImages.mat');
-PATH.forestFilledByImages = fullfile(DIR.result, 'forestFilledByImages.mat');
+% PATH.forestSkeletonByImages = fullfile(DIR.result, 'forestSkeletonByImages.mat');
+% PATH.forestFilledByImages = fullfile(DIR.result, 'forestFilledByImages.mat');
% patch sampling parameters
BOX.sampleFreq = 4; % space between sampled patches
View
22 do_test.m
@@ -18,9 +18,12 @@ function do_test(config_file)
imageNames = textscan(fid, '%s');
imageNames = imageNames{1};
fclose(fid);
+labelNames = strcat([DIR.groundTruth, '/'], regexprep(imageNames, '\.bmp$', '_GT.bmp'));
+imageNamesFull = strcat([DIR.images, '/'], imageNames);
+
numTest = numel(imageNames);
wait = waitbar(0, 'testing');
-
+fprintf('start testing\n');
for i = 1:numTest
patches = getPatches(imageNames{i}, DIR, [], BOX, []);
dist = zeros(numClass, size(patches, 3), FOREST.numTree);
@@ -32,15 +35,22 @@ function do_test(config_file)
% normalize
distAll = bsxfun(@rdivide, distAll+(1e-4./numClass), sum(distAll)+1e-4);
[~, pred] = max(distAll, [], 1);
- I = imread(fullfile(DIR.images, imageNames{i}));
+ I = imread(imageNamesFull{i});
+ L = imread(labelNames{i});
[r, c, ~] = size(I);
pred = reshape(pred, r, c);
predRGB = label2rgb(pred, LABELS./255);
- h=figure(1); imagesc(I), hold on;
+ h=figure(1); subplot(131); imagesc(I); hold on;
himage = imagesc(predRGB);
- set(himage, 'AlphaData', 0.4);
- print(h, fullfile(DIR.result, imageNames{i}))
- % imwrite(fullfile(DIR.result, imageNames{i}), 'bmp');
+ set(himage, 'AlphaData', 0.4);
+ axis off image; title('overlay');
+ subplot(132); imagesc(predRGB); axis off image;
+ title('prediction');
+ subplot(133); imagesc(L); axis off image;
+ title('ground truth');
+ % print(h, fullfile(DIR.result, imageNames{i}))
+ keyboard
+ imwrite(predRGB, fullfile(DIR.result, imageNames{i}), 'bmp');
wait = waitbar(i/numTest, wait, sprintf(['done evaluating test ' ...
'image %d'], i));
end
View
70 do_train.m
@@ -10,14 +10,14 @@ function do_train(config_file)
DEBUG = 0;
eval(config_file); % load settings
%% learn the splits
-if ~exist(PATH.forestSkeletonByImages, 'file')
+if ~exist(PATH.forestSkeleton, 'file')
% % make training patches
- % if ~exist(PATH.trainingPatches, 'file')
- % data = sampleTrainingImages(config_file);
- % else, load(PATH.trainingPatches); end
- if ~exist(PATH.trainingPointsSub, 'file')
- [data, Is] = sampleTrainingImagesByImages(config_file);
- else, load(PATH.trainingPointsSub); end
+ if ~exist(PATH.trainingPatches, 'file')
+ data = sampleTrainingImages(config_file);
+ else, load(PATH.trainingPatches); end
+ % if ~exist(PATH.trainingPointsSub, 'file')
+ % [data, Is] = sampleTrainingImagesByImages(config_file);
+ % else, load(PATH.trainingPointsSub); end
% make label weights
if ~exist(PATH.labelWeights, 'file')
@@ -33,25 +33,25 @@ function do_train(config_file)
% randomly pick dataPerTree amount of training data for
% each tree
subData = data(rand(numel(data), 1) < FOREST.dataPerTree);
- % patches = [subData.patch];
- % % make it d by d by N by 3
- % patches = reshape(patches, size(patches, 1), ...
- % size(patches, 1), length(subData), 3);
- % labels = double([subData.label]);
- % tree.trainDepthFirst(patches, labels);
- tree.trainDepthFirstByImages(subData, Is);
+ patches = [subData.patch];
+ % make it d by d by N by 3
+ patches = reshape(patches, size(patches, 1), ...
+ size(patches, 1), length(subData), 3);
+ labels = double([subData.label]);
+ tree.trainDepthFirst(patches, labels);
+ % tree.trainDepthFirstByImages(subData, Is);
forest(i) = tree;
wait = waitbar(i/FOREST.numTree, wait, sprintf(['finished learning ' ...
'tree: %d'], i));
end
close(wait);
- save(PATH.forestSkeletonByImages, 'forest');
+ save(PATH.forestSkeleton, 'forest');
else
- load(PATH.forestSkeletonByImages);
+ load(PATH.forestSkeleton);
end
%% fill the forest
-if ~exist(PATH.forestFilledByImages, 'file');
+if ~exist(PATH.forestFilled, 'file');
fprintf('fill the forest\n');
fid = fopen(PATH.trainingNames, 'r');
imageNames = textscan(fid, '%s');
@@ -60,36 +60,36 @@ function do_train(config_file)
numTrain = numel(imageNames);
wait = waitbar(0, 'filling the tree');
- % for i = 1:numTrain
- % data = getPatches(imageNames{i}, DIR, LABELS, BOX, TRANSFORM);
- % if ~isempty(data)
- % patches = [data.patch];
- % % make it d by d by N by 3
- % patches = reshape(patches, size(patches, 1), ...
- % size(patches, 1), numel(data), 3);
- % labels = double([data.label]);
- % for t = 1:FOREST.numTree
- % forest(t).fillAll(patches, labels);
- % end
- % end
- % wait = waitbar(i/numTrain, wait, sprintf(['filling training ' ...
- % 'image: %d'], i));
- % end
for i = 1:numTrain
- [data, Is] = getPatchesByImages(imageNames{i}, DIR, LABELS, BOX, TRANSFORM);
+ data = getPatches(imageNames{i}, DIR, LABELS, BOX, TRANSFORM);
if ~isempty(data)
+ patches = [data.patch];
+ % make it d by d by N by 3
+ patches = reshape(patches, size(patches, 1), ...
+ size(patches, 1), numel(data), 3);
+ labels = double([data.label]);
for t = 1:FOREST.numTree
- forest(t).fillAllByImages(data, Is);
+ forest(t).fillAll(patches, labels);
end
end
wait = waitbar(i/numTrain, wait, sprintf(['filling training ' ...
'image: %d'], i));
end
+ % for i = 1:numTrain
+ % [data, Is] = getPatchesByImages(imageNames{i}, DIR, LABELS, BOX, TRANSFORM);
+ % if ~isempty(data)
+ % for t = 1:FOREST.numTree
+ % forest(t).fillAll(data, Is);
+ % end
+ % end
+ % wait = waitbar(i/numTrain, wait, sprintf(['filling training ' ...
+ % 'image: %d'], i));
+ % end
fprintf('normalize tree\n');
for t = 1:FOREST.numTree
forest(t).normalizeAll();
end
- save(PATH.forestFilledByImages, 'forest');
+ save(PATH.forestFilled, 'forest');
close(wait);
end
View
1  extractBost.m
@@ -34,6 +34,5 @@
assert(all(bost(start:endInd) == 0));
bost(start:endInd) = hist;
prior = prior + sum(dists, 2)./R;
- keyboard
end
prior = prior./numTree;
View
32 resultsByImage/computeFeatureByImages.m
@@ -0,0 +1,32 @@
+function [val, feat] = computeFeatureByImages(data, Is, param)
+% COMPUTEFEATUREBYIMAGES  Pixel-combination feature for samples that are
+% stored as a centre point plus the id of their source image, instead of
+% as explicit d x d x 3 patches.
+%
+%   [val, feat] = computeFeatureByImages(data, Is, param)
+%
+%   Inputs
+%     data  : struct array with fields imageId, row, col (patch centre in
+%             the coordinates of Is{imageId}).
+%     Is    : cell array of images, indexed by data(i).imageId.
+%     param : either a method id (1..4), in which case a new random
+%             feature is drawn, or a feature struct previously returned
+%             by this function, which is then re-applied unchanged.
+%
+%   Outputs
+%     val  : numel(data) x 1 double vector of feature responses.
+%     feat : struct with fields
+%              rows, cols : 2 x 1 offsets relative to the patch centre
+%              channels   : 2 x 1 channel indices
+%              method     : 1 'unary', 2 'subAbs', 3 'addTwo', 4 'sub'
+%
+%   Raises computeFeatureByImages:unknownMethod when feat.method is not 1..4.
+if ~isstruct(param)
+    d = (15-1)/2; % half-width of the 15-pixel patch; for the moment
+                  % hard-coded, later move it to DT's constant and have
+                  % this entire function inside DecisionTree.m
+    % Offsets cover the whole window -d..d. The previous form
+    % randi(2*d) - d produced 1-d..d and could never select offset -d.
+    feat.rows = randi(2*d + 1, 2, 1) - d - 1;
+    feat.cols = randi(2*d + 1, 2, 1) - d - 1;
+    feat.channels = randi(3, 2, 1);
+    feat.method = param;
+else
+    feat = param;
+end
+
+n = numel(data);
+A = zeros(n, 1);
+B = zeros(n, 1);
+if n > 0
+    imageIds = [data.imageId];
+    rows = [data.row];
+    cols = [data.col];
+    % Gather all probes of one image with a single linear-index lookup
+    % instead of looping over every sample; the per-sample loop was the
+    % bottleneck. Assigning into the double arrays A and B converts the
+    % pixel values to double, exactly as the element-wise loop did.
+    for id = unique(imageIds(:))'
+        sel = (imageIds == id);
+        img = Is{id};
+        sz = size(img);
+        m = nnz(sel);
+        idxA = sub2ind(sz, rows(sel) + feat.rows(1), ...
+                       cols(sel) + feat.cols(1), ...
+                       repmat(feat.channels(1), 1, m));
+        idxB = sub2ind(sz, rows(sel) + feat.rows(2), ...
+                       cols(sel) + feat.cols(2), ...
+                       repmat(feat.channels(2), 1, m));
+        A(sel) = img(idxA);
+        B(sel) = img(idxB);
+    end
+end
+
+switch feat.method
+    case 1 % 'unary'
+        val = A;
+    case 2 % 'subAbs'
+        val = abs(A - B);
+    case 3 % 'addTwo'
+        val = A + B;
+    case 4 % 'sub'
+        val = A - B;
+    otherwise
+        % Without this branch val would be left undefined.
+        error('computeFeatureByImages:unknownMethod', ...
+              'Unknown feature method: %s', mat2str(feat.method));
+end
+
View
76 resultsByImage/getPatchesByImages.m
@@ -0,0 +1,76 @@
+function [data, Is] = getPatchesByImages(fname, DIR, LABELS, BOX, TRANSFORM)
+% GETPATCHESBYIMAGES  Collect patch centres for every pixel of one image.
+%
+%   [data, Is] = getPatchesByImages(fname, DIR, LABELS, BOX, TRANSFORM)
+%
+%   No subsampling: a sample is produced for every pixel. Edge cases are
+%   handled by symmetric padding of the image by the patch radius, so the
+%   stored row/col are coordinates in the PADDED image.
+%
+%   Inputs
+%     fname     : image file name, relative to DIR.images.
+%     DIR       : struct with fields images and groundTruth (directories).
+%     LABELS    : K x 3 table of ground-truth colours; the row index of the
+%                 matching colour is the class label.
+%     BOX       : struct with fields size (patch width) and cform (colour
+%                 transform handed to applycform).
+%     TRANSFORM : struct with field numTransform (also passed through to
+%                 transformImage), or [] to select test mode.
+%
+%   Outputs
+%     data : struct array with fields imageId, row, col and, in training
+%            mode only, label. Pixels whose ground-truth colour is not in
+%            LABELS are dropped in training mode.
+%     Is   : cell array of padded, colour-converted images indexed by
+%            data(k).imageId. In training mode it has
+%            TRANSFORM.numTransform+1 cells; cells 2..end stay empty when
+%            no labelled pixel is found in the untransformed image.
+rad = (BOX.size-1)/2; % of patch
+if ~isempty(TRANSFORM)
+    %% training mode: original image plus numTransform transformed copies
+    Is = cell(TRANSFORM.numTransform+1, 1);
+    k = 1;
+    data = struct([]);
+    I = imread(fullfile(DIR.images, fname));
+    L = imread(fullfile(DIR.groundTruth, regexprep(fname, '\.bmp$', '_GT.bmp')));
+    Ipad = padarray(I, [rad, rad], 'symmetric');
+    [r, c, ~] = size(Ipad);
+    Ilab = applycform(Ipad, BOX.cform);
+    Is{1} = Ilab;
+    % centres = every pixel of the padded image that maps to a real pixel
+    [ri, ci] = ndgrid(1:r, 1:c);
+    ri = ri(rad+1:r-rad, rad+1:c-rad);
+    ci = ci(rad+1:r-rad, rad+1:c-rad);
+    %% collect patches without transformation
+    for j = 1:numel(ri)
+        % need -rad because L wasn't padded
+        gt = find(L(ri(j)-rad, ci(j)-rad, 1) == LABELS(:, 1) & ...
+                  L(ri(j)-rad, ci(j)-rad, 2) == LABELS(:, 2) & ...
+                  L(ri(j)-rad, ci(j)-rad, 3) == LABELS(:, 3) );
+        if ~isempty(gt)
+            data(k).label = gt;
+            data(k).imageId = 1;
+            data(k).row = ri(j);
+            data(k).col = ci(j);
+            k = k + 1;
+        end
+    end
+    % if no appropriate label is found no need to do transformation
+    if isempty(data), return; end
+    %% collect patches after transformation
+    for t = 1:TRANSFORM.numTransform
+        [I2, L2] = transformImage(I, L, TRANSFORM);
+        % turn it to lab
+        Ilab2 = applycform(I2, BOX.cform);
+        Ilab2 = padarray(Ilab2, [rad, rad], 'symmetric');
+        Is{t+1} = Ilab2; % semicolon added: this used to echo the whole image
+        [r, c, ~] = size(Ilab2);
+        [ri, ci] = ndgrid(1:r, 1:c);
+        ri = ri(rad+1:r-rad, rad+1:c-rad);
+        ci = ci(rad+1:r-rad, rad+1:c-rad);
+        for j = 1:numel(ri)
+            % -rad again: L2 is not padded
+            gt = find(L2(ri(j)-rad, ci(j)-rad, 1) == LABELS(:, 1) & ...
+                      L2(ri(j)-rad, ci(j)-rad, 2) == LABELS(:, 2) & ...
+                      L2(ri(j)-rad, ci(j)-rad, 3) == LABELS(:, 3) );
+            if ~isempty(gt)
+                data(k).label = gt;
+                data(k).imageId = t+1;
+                data(k).row = ri(j);
+                data(k).col = ci(j);
+                k = k + 1;
+            end
+        end
+    end
+else
+    %% test mode: no labels, no transformations
+    I = imread(fullfile(DIR.images, fname));
+    Ipad = padarray(I, [rad, rad], 'symmetric');
+    [r, c, ~] = size(Ipad);
+    Ilab = applycform(Ipad, BOX.cform);
+    Is = {Ilab};
+    [ri, ci] = ndgrid(1:r, 1:c);
+    ri = ri(rad+1:r-rad, rad+1:c-rad);
+    ci = ci(rad+1:r-rad, rad+1:c-rad);
+    data = struct([]);
+    for i = 1:numel(ri)
+        data(i).imageId = 1;
+        % index with the loop variable i; the previous ri(j)/ci(j) used
+        % a variable that does not exist in this branch
+        data(i).row = ri(i);
+        data(i).col = ci(i);
+    end
+end
+
View
49 resultsByImage/sampleTrainingImagesByImages.m
@@ -0,0 +1,49 @@
+function [data, Is] = sampleTrainingImagesByImages(config_file)
+% SAMPLETRAININGIMAGESBYIMAGES  Subsampled, labelled patch centres for the
+% whole training set.
+%
+%   [data, Is] = sampleTrainingImagesByImages(config_file)
+%
+%   Subsamples every training image by BOX.sampleFreq and, for each
+%   subsampled pixel whose ground-truth colour is listed in LABELS, stores
+%   its label, row, col and its image id (the image this pixel belongs to)
+%   in a struct.
+%
+%   Input
+%     config_file : name of the configuration script. It is expected to
+%                   define PATH, DIR, BOX and LABELS, which are the
+%                   variables read below.
+%
+%   Outputs
+%     data : struct array with fields label, imageId, row, col.
+%     Is   : numTrain x 1 cell array of colour-converted training images,
+%            indexed by data(k).imageId.
+%
+%   The result is cached in PATH.trainingPointsSub. If that file already
+%   exists it is loaded instead of being recomputed.
+
+% NOTE(review): eval executes config_file as MATLAB code in this workspace;
+% only pass the name of a trusted configuration script.
+eval(config_file);
+
+% Reuse the cache. The guard used to test PATH.trainingPatches, which is a
+% different file from the one written below, and returned with data
+% unassigned whenever that file existed.
+if exist(PATH.trainingPointsSub, 'file')
+    cached = load(PATH.trainingPointsSub, 'data', 'Is');
+    data = cached.data;
+    Is = cached.Is;
+    return;
+end
+
+fid = fopen(PATH.trainingNames, 'r');
+if fid == -1
+    error('sampleTrainingImagesByImages:cannotOpen', ...
+          'Cannot open training list: %s', PATH.trainingNames);
+end
+imageNames = textscan(fid, '%s');
+imageNames = imageNames{1};
+fclose(fid);
+
+labelNames = strcat([DIR.groundTruth, '/'], regexprep(imageNames, '\.bmp$', '_GT.bmp'));
+imageNames = strcat([DIR.images, '/'], imageNames);
+numTrain = numel(imageNames);
+Is = cell(numTrain, 1);
+
+%% select patch centers from each training image
+rad = (BOX.size-1)/2; % of patch
+data = struct([]);
+k = 1;
+wait = waitbar(0, 'preprocessing data');
+for i = 1:numTrain
+    I = imread(imageNames{i});
+    L = imread(labelNames{i});
+    Ilab = applycform(I, BOX.cform);
+    Is{i} = Ilab;
+    [r, c, ~] = size(I);
+    [ri, ci] = ndgrid(1:r, 1:c);
+    % subsampled center pixels of patches; centres stay rad pixels away
+    % from the border because the images are not padded here
+    ri = ri(rad+1:BOX.sampleFreq:r-rad, rad+1:BOX.sampleFreq:c-rad);
+    ci = ci(rad+1:BOX.sampleFreq:r-rad, rad+1:BOX.sampleFreq:c-rad);
+    for j = 1:numel(ri)
+        gt = find(L(ri(j), ci(j), 1) == LABELS(:, 1) & ...
+                  L(ri(j), ci(j), 2) == LABELS(:, 2) & ...
+                  L(ri(j), ci(j), 3) == LABELS(:, 3) );
+        if ~isempty(gt)
+            data(k).label = gt;
+            data(k).imageId = i;
+            data(k).row = ri(j);
+            data(k).col = ci(j);
+            k = k + 1;
+        end
+    end
+    wait = waitbar(i/numTrain, wait, sprintf(['preprocessing training ' ...
+                        'image: %d'], i));
+end
+close(wait);
+save(PATH.trainingPointsSub, 'data', 'Is');
Please sign in to comment.
Something went wrong with that request. Please try again.