-
Notifications
You must be signed in to change notification settings - Fork 1
/
SparseCoding2.m
202 lines (186 loc) · 10.7 KB
/
SparseCoding2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
classdef SparseCoding2 < handle
    % SparseCoding2  Sparse coder over a dictionary of binocular basis functions.
    %
    %   Holds a dictionary (basis) whose columns are basis functions and
    %   offers three encoders (softmaxEncode, sparseEncode, fullEncode) that
    %   write their result into currentCoef / currentError, plus a gradient
    %   step (stepTrain) that adapts the dictionary to the last encoding.

    properties
        nBasis;         % total number of basis functions (columns of basis)
        nBasisUsed;     % number of basis functions used to encode each patch in sparse mode
        basisSize;      % size of each (binocular) basis vector: patchSize * patchSize * 2 (left + right eye)
        eta;            % learning rate
        temperature;    % temperature in softmax
        basis;          % all basis functions, basisSize x nBasis
        basisHist;      % basis function history (snapshots stacked along the 3rd dimension)
        currentCoef;    % current coefficient matrix
        currentError;   % current reconstruction error
        sizeBatch;      % image batch size's 2nd dimension
        selectedBasis;  % indicates for each basis how often it has been selected (kept normalised to sum 1)
        BFinit;         % describes the initialization of basis functions
        BFfitFreq;      % fit frequencies to BFs or wavelengths?
        BFIdent;        % helps identifying the scale inside the sparse coder
    end

    methods
        function obj = SparseCoding2(PARAM)
            % Constructor.
            %
            % @param PARAM: numeric vector
            %   [nBasis, nBasisUsed, basisSize, eta, temperature, ...
            %    BFinit, BFfitFreq, BFIdent, sizeBatch]
            %
            % BFinit selects the dictionary initialisation:
            %   1        white noise, columns scaled to unit norm
            %   2        non-aligned Gabor wavelets (BaseGenerator)
            %   3        monocular Gabor wavelets (BaseGenerator, one eye zeroed)
            %   4, 41-45 basis functions loaded from a stored model
            % Any other value raises an error.
            if nargin < 1 || numel(PARAM) < 9
                error('SparseCoding2:invalidParam', ...
                      'PARAM must contain at least 9 entries.');
            end

            obj.nBasis = PARAM(1);
            obj.selectedBasis = zeros(PARAM(1), 1);
            obj.nBasisUsed = PARAM(2);
            obj.basisSize = PARAM(3);
            obj.eta = PARAM(4);
            obj.temperature = PARAM(5);
            obj.BFinit = PARAM(6);
            obj.BFfitFreq = PARAM(7);
            obj.BFIdent = PARAM(8);
            obj.sizeBatch = PARAM(9);

            switch obj.BFinit
                case 1 % white noise BF init
                    noise = rand(obj.basisSize, obj.nBasis) - 0.5;
                    % one normalisation is enough to obtain unit-norm columns
                    obj.basis = bsxfun(@rdivide, noise, sqrt(sum(noise .^ 2, 1)));

                case 2 % non-aligned Gabor wavelets
                    obj.basis = BaseGenerator(0, 0, sqrt(PARAM(3) / 2), PARAM(1));

                case 3 % monocular Gabor wavelets
                    obj.basis = BaseGenerator(0, 0, sqrt(PARAM(3) / 2), PARAM(1));
                    x = randperm(obj.nBasis);
                    % floor() keeps the indices integral when nBasis is odd;
                    % for even nBasis the split is exactly half / half.
                    half = floor(obj.nBasis / 2);
                    obj.basis(1 : end / 2, x(1 : half)) = 0;            % monocular right
                    obj.basis(end / 2 + 1 : end, x(half + 1 : end)) = 0; % monocular left
                    obj.basis = bsxfun(@rdivide, obj.basis, sqrt(sum(obj.basis .^ 2, 1)));

                case {4, 41, 42, 43, 44, 45} % take preloaded BFs
                    obj.basis = SparseCoding2.loadPretrainedBasis(obj.BFinit, obj.BFIdent);

                otherwise
                    error('Unrecognized basis function initialization')
            end

            obj.basisHist = [];
            obj.currentCoef = zeros(obj.nBasis, obj.sizeBatch);     % e.g. 288x81
            obj.currentError = zeros(obj.basisSize, obj.sizeBatch); % e.g. 128x81
            %TODO maybe reimplement reloading basis functions
        end

        %%% Encode the input images according to a softmax distribution
        % For each of the nBasisUsed rounds one basis function per patch is
        % drawn at random, with probabilities given by the softmax of
        % |basis' * residual| / temperature, and its projection onto the
        % residual is added to the coefficients.
        % @param imageBatch: input image patches batch (one patch per column)
        function softmaxEncode(this, imageBatch)
            % Size everything from the batch actually passed in, so the
            % encoder also works after currentCoef was resized elsewhere
            % (fullEncode) and never inherits NaN/Inf from a previous run.
            nPatches = size(imageBatch, 2);
            patchIdx = 1 : nPatches;
            coef = zeros(this.nBasis, nPatches);
            residual = imageBatch;

            for count = 1 : this.nBasisUsed
                proj = this.basis' * residual;

                % subtract the column maximum before the softmax
                logits = abs(proj) / this.temperature;
                logits = bsxfun(@minus, logits, max(logits, [], 1));

                % inverse-CDF sampling: pick the first basis whose cumulative
                % probability reaches the uniform draw
                cdf = cumsum(softmax(logits), 1);
                draws = rand(1, nPatches);
                index = sum(bsxfun(@lt, cdf, draws), 1) + 1;
                % rounding can leave cdf(end) marginally below the draw;
                % fall back to the last basis instead of running out of range
                index = min(index, this.nBasis);

                linearIndex = sub2ind(size(proj), index, patchIdx);
                coef(linearIndex) = coef(linearIndex) + proj(linearIndex);
                residual = imageBatch - this.basis * coef;
            end

            this.currentCoef = coef;
            this.currentError = residual;
        end

        %%% Encode the input images with the best matched basis
        % Greedy selection: in each of the nBasisUsed rounds the basis with
        % the largest absolute correlation to the residual is chosen per
        % patch. The residual correlations are updated through the
        % basis-to-basis correlation matrix instead of recomputing
        % basis' * residual (see Yu's doc).
        % @param imageBatch: input image patches batch (one patch per column)
        function sparseEncode(this, imageBatch)
            nPatches = size(imageBatch, 2);
            patchIdx = 1 : nPatches;
            coef = zeros(this.nBasis, nPatches);

            corrl = this.basis' * imageBatch;   % correlation of each basis with each patch
            corrBB = this.basis' * this.basis;  % correlation between basis functions

            for count = 1 : this.nBasisUsed
                [~, index] = max(abs(corrl), [], 1);                  % best basis per patch
                linearIndex = sub2ind(size(corrl), index, patchIdx);  % its position in corrl / coef
                pCorr = corrl(linearIndex);                           % coefficient increment per patch
                coef(linearIndex) = coef(linearIndex) + pCorr;
                corrl = corrl - bsxfun(@times, corrBB(:, index), pCorr);
            end

            this.currentCoef = coef;
            this.currentError = imageBatch - this.basis * coef; % e.g. 128x81 = 128x81 - 128x288 * 288x81
        end

        %%% Calculate the correlation between input image and the basis
        % Uses every basis function: coefficients are the plain projections.
        % @param imageBatch: input image patches batch (one patch per column)
        function fullEncode(this, imageBatch)
            this.currentCoef = this.basis' * imageBatch;
            this.currentError = imageBatch - this.basis * this.currentCoef;
        end

        %%% Update basis functions
        % Moves the dictionary along currentError * currentCoef' (averaged
        % over the batch), rescales every column to unit norm and updates
        % the selection statistics in selectedBasis.
        function stepTrain(this)
            deltaBases = this.currentError * this.currentCoef' / size(this.currentError, 2);
            updated = this.basis + this.eta * deltaBases;

            colNorm = sqrt(sum(updated .^ 2, 1));
            colNorm(colNorm == 0) = 1; % leave all-zero columns untouched instead of producing NaN
            this.basis = bsxfun(@rdivide, updated, colNorm);

            % also update the selected basis functions
            usedBasis = sum(this.currentCoef ~= 0, 2);
            nSelections = sum(usedBasis);
            if nSelections > 0
                % skipped when nothing was selected: 0/0 would otherwise
                % turn the whole statistic into NaN for good
                this.selectedBasis = this.selectedBasis + usedBasis / nSelections;
                this.selectedBasis = this.selectedBasis / sum(this.selectedBasis);
            end
        end

        %%% Track the evolution of all basis functions over time
        % Appends the current dictionary as a new slice of basisHist.
        function saveBasis(this)
            this.basisHist = cat(3, this.basisHist, this.basis);
        end

        %%% Display the basis functions (Zhao Yu code) at iteration t
        % Basis functions are sorted by how close the energy of their first
        % (left-eye) half is to 0.5 and drawn as a grid of tiles; t is shown
        % as the figure title.
        % @param t: iteration number used for the title
        function displayBasis(this, t)
            [di, num] = size(this.basis);
            if num == 0
                return;
            end

            % Grid layout: the largest divisor of num not exceeding sqrt(num)
            % gives the rows. For 288 basis functions this is 16 x 18.
            nRows = floor(sqrt(num));
            while mod(num, nRows) ~= 0
                nRows = nRows - 1;
            end
            nCols = num / nRows;

            leftEnergy = abs(sum(this.basis(1 : end / 2, :) .^ 2, 1) - 0.5);
            [~, order] = sort(leftEnergy);

            % h = gcf;
            % set(h,'Position',[1 1 800 600]);
            % scrsz = get(0,'ScreenSize');
            % set(h,'Position',[scrsz(1) scrsz(2) scrsz(3) scrsz(4)]);
            subplot(1, 1, 1);

            % Turns one basis vector (di x 1) into a tile with the first-half
            % patch stacked above the second-half patch. The "- 1 ... + 1"
            % around padarray makes the padding value 1 instead of 0.
            patchSide = sqrt(di / 2);
            tileFun = @(blc_struct) padarray(padarray(reshape(permute(padarray(reshape(blc_struct.data, patchSide, ...
                patchSide, 2), [1, 1], 'pre'), [1, 3, 2]), (patchSide + 1) * 2, patchSide + 1), ...
                [1, 1], 'post') - 1, [1 1], 'pre') + 1;

            stacked = reshape(this.basis(:, order), di * nRows, nCols);
            stacked = stacked / max(max(abs(stacked))) + 0.5;
            mosaic = padarray(padarray(blockproc(stacked, [di, 1], tileFun) - 1, [1 1], 'post') + 1, [2, 2]);
            imshow(mosaic);
            title(num2str(t));
            drawnow;
        end
    end

    methods (Static, Access = private)
        function basis = loadPretrainedBasis(initCode, scaleIdx)
            % Load the basis functions of scale scaleIdx from a stored model.
            %
            % @param initCode: BFinit value (4 and 41 share the same model file)
            % @param scaleIdx: index into model.scModel
            %
            % Alternative model files that were used with initCode 4:
            %   fixed 3 deg strabism
            %     /home/aecgroup/aecdata/Results/eLifePaper/strabism/19-02-18_500000iter_2_AllfixAt6m_filtB_29_strabAngle_3_seed2/model.mat
            %   laplacian 3 deg strabism
            %     /home/aecgroup/aecdata/Results/eLifePaper/inducedStrabism/20-09-22_500000iter_2_inducedStrab_3deg_lapSig02_od05-1m/model.mat
            switch initCode
                case {4, 41}
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/19-06-04_500000iter_1_fsize6std_filtBoth_45_prob_1_seed1/model.mat';
                case 42
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-18_500000iter_2_fsize6std_filtBoth_45_prob_1_seed2/model.mat';
                case 43
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-19_500000iter_3_fsize6std_filtBoth_45_prob_1_seed3/model.mat';
                case 44
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/18-10-18_500000iter_4_fsize6std_filtBoth_45_prob_1_seed4/model.mat';
                case 45
                    modelFile = '/home/aecgroup/aecdata/Results/eLifePaper/explFilterSizes/19-06-05_500000iter_5_fsize6std_filtBoth_45_prob_1_seed5/model.mat';
                otherwise
                    error('Unrecognized basis function initialization')
            end
            model = load(modelFile);
            model = model.model;
            basis = model.scModel{scaleIdx}.basis;
        end
    end
end