diff --git a/examples/exampleL2Inf.m b/examples/exampleL2Inf.m index e3147ad..d25aa45 100644 --- a/examples/exampleL2Inf.m +++ b/examples/exampleL2Inf.m @@ -3,7 +3,7 @@ %% read data addpath(genpath('..')); -dimension = 50000; +dimension = 500000; weightL2 = 3; weightL2Inf = 0.1; dataPart = [5*rand(1,dimension);5*rand(1,dimension)]; diff --git a/prox/L2InfProxDual.m b/prox/L2InfProxDual.m index fae3945..60ce9d8 100644 --- a/prox/L2InfProxDual.m +++ b/prox/L2InfProxDual.m @@ -13,82 +13,31 @@ function applyProx(obj,main,dualNumbers,~) end yTildeNorm = sqrt(yTildeNorm); - - h = @(lambda) (yTildeNorm > lambda) .* (yTildeNorm - lambda); - g = @(lambda) sum(h(lambda)) - obj.factor; - %h = @(lambda) (sortyTildeNorm > lambda) .* (sortyTildeNorm - lambda); - %g = @(lambda) sum(h(lambda)) - obj.factor; - %lambda = max(0,fzero(g, max(yTildeNorm))); - sortyTildeNorm = sort(yTildeNorm, 'descend'); - for findIndex=1:length(sortyTildeNorm) - if g(sortyTildeNorm(findIndex)) > 0 - break + + yTildeSum = 0; + g2 = -obj.factor; + lambda = 0; + for index=2:length(sortyTildeNorm) + lambda = sortyTildeNorm(index); + yTildeSum = yTildeSum + sortyTildeNorm(index - 1); + g2 = yTildeSum - (index - 1) * lambda - obj.factor; + + if g2 >= 0 + break; end end - - lambda = 0; - if ~isempty(findIndex) && findIndex ~= 1 && (findIndex ~= length(sortyTildeNorm) || g(sortyTildeNorm(findIndex)) > 0) - lambda = (sum(sortyTildeNorm(1:findIndex-1)) - obj.factor) / (findIndex - 1); + + if g2 < 0 + lambda = 0; + else + lambda = (yTildeSum - obj.factor) / (index - 1); end - - %disp(findIndex) - %disp(lambda) - %disp(sortyTildeNorm) - %disp('-------------') - %waitforbuttonpress; - - %disp(lambda); - %zeros = linspace(-1,max(yTildeNorm) + 1, 100); - %if length(zeros) > 3 - % disp("yeah") - %end - %gZeros = g(zeros); - %h = figure(1); - %plot(zeros, gZeros); - %hold on; - %plot([min(zeros)-10, max(zeros)+10], [0, 0]); - %waitforbuttonpress; - %close(h); - %if sum(yTildeNorm > lambda) > 2 - % disp("hooray") - %end - - - - + for i=1:obj.numVars main.y{dualNumbers(i)} = (yTildeNorm > lambda) .* main.yTilde{dualNumbers(i)} .* (1 - lambda ./ yTildeNorm); main.y{dualNumbers(i)}(yTildeNorm <= lambda) = 0; end - - - -% %first guess for main.y -% for i=1:obj.numVars -% main.y{dualNumbers(i)} = main.yTilde{dualNumbers(i)}; -% end -% yNorm = yTildeNorm; -% -% counter = 1; -% while (sum(yNorm) >= obj.factor) && (counter <= length(sortyTildeNorm)) -% lambda = sortyTildeNorm(counter); -% -% %update guess for main.y -% for i=1:obj.numVars -% main.y{dualNumbers(i)} = (yTildeNorm > lambda) .* main.yTilde{dualNumbers(i)} .* (1 - lambda ./ yTildeNorm); -% main.y{dualNumbers(i)}(yTildeNorm <= lambda) = 0; -% end -% -% yNorm = 0; -% for i=1:obj.numVars -% yNorm = yNorm + main.y{dualNumbers(i)}.^2; -% end -% yNorm = sqrt(yNorm); -% %yNorm = yTildeNorm - lambda; -% -% counter = counter + 1; -% end end end end