In [None]:
# Julia version code
"""
nonnegative linear regression 
"""

# Notes
# 1. We set min A_ij = 1. This can be done either by dividing A by min A_ij or by simply adding 1.
# The first approach is somewhat of a cheat, because the rescaling by itself makes the error small.
# If min A_ij is not 1, the initial error can be LARGE.

# 2. If all x_j coordinates are updated in parallel, performance is comparable to scipy (update_v_parllel and update_X_paralle).

# 3. With coordinate-wise updates, scaling down by n works better than not scaling.

# 4. In experiments, our theoretical algorithm shows occasional large jumps in the error.

# 5. The large initial error is fixed by running one full update step first.

# 6. The solution x is extremely sparse.

# 7. TODO: handle the m << n case.

import Pkg; Pkg.add("Plots")
using LinearAlgebra, BenchmarkTools, Distributions, Plots

# Type aliases for multivariate distributions.
# NOTE(review): none of these aliases is referenced by the code shown here, and
# Distributions.jl already exports aliases with the same names; they look like
# leftovers. Confirm no other cell relies on them before removing.
const MultivariateDistribution{S<:ValueSupport} = Distribution{Multivariate,S}

const DiscreteMultivariateDistribution   = Distribution{Multivariate, Discrete}
const ContinuousMultivariateDistribution = Distribution{Multivariate, Continuous}

"""
    compute_scaling(A)

Return a `1 × size(A, 2)` row whose `j`-th entry is `-1 / ||A[:, j]||^2`, the
negated reciprocal of each column's squared Euclidean norm. A zero column
produces `-Inf`.
"""
function compute_scaling(A)
    col_sqnorms = sum(abs2, A; dims=1)
    return (-1) ./ col_sqnorms
end

"""
    init_all(epsilon, m, n, A)

Build the starting state of the randomized coordinate method for a matrix `A`
with `n` columns.

Returns `(ybar, ykm, xtildek, ktotal, ak, akm, Ak, Akm, v, xkm, scaling_vector)`:
- `ktotal = ceil(Int, n / sqrt(epsilon))` is an approximate iteration budget;
- `ak, akm, Ak, Akm` are the initial weights (`1/n^2`, `1/n`, `(1+n)/n^2`, `1/n`);
- `scaling_vector = compute_scaling(A)`;
- `v` and `xkm` are zero except at one coordinate `jk` drawn uniformly from `1:n`;
- `ykm = A * xkm`, `ybar = (n + 1) * ykm`, and `xtildek` is an independent copy of `xkm`.

`m` is accepted for interface compatibility but is not used.
"""
function init_all(epsilon, m, n, A)
    ktotal = ceil(Int, n / sqrt(epsilon)) # note that this is just an approx ktotal

    akm = 1 / n
    Akm = 1 / n
    ak = 1 / n^2
    Ak = (1 + n) / n^2

    scaling_vector = compute_scaling(A)

    # Uniform starting coordinate (same distribution as a one-trial Multinomial draw).
    jk = rand(1:n)

    v = zeros(n) # vector with same length as x, used to obtain x
    v[jk] = 1
    xkm = zeros(n)
    # NOTE(review): this sets x_1[jk] = ||A[:, jk]||^2, while update_x clamps the
    # coordinate to [0, 1/||A[:, jk]||^2]; the two agree only for unit-norm columns — confirm.
    xkm[jk] = -1 / scaling_vector[jk] # xkm = x_1

    ykm = A * xkm # ykm = y_1
    # Copy, not alias: an in-place write to xkm must never change xtildek as well.
    xtildek = copy(xkm) # widetilde{x_k}
    ybar = (n + 1) * ykm # ybar = ybar_1

    return ybar, ykm, xtildek, ktotal, ak, akm, Ak, Akm, v, xkm, scaling_vector
end

"""
    update_v(jk, A, ybar, v, ak, n=size(A, 2))

Update coordinate `jk` of `v` in place, `v[jk] += n * ak * (A[:, jk]' * ybar - 1)`,
and return `v`.

`n` is the number of coordinates (columns of `A`). It defaults to `size(A, 2)`
instead of being read from a global variable.
"""
function update_v(jk, A, ybar, v, ak, n=size(A, 2))
    # view avoids copying the column on every iteration
    v[jk] += n * ak * (dot(view(A, :, jk), ybar) - 1) # ybar = ybar_{k-1}
    return v
end

"""
    update_x(v, jk, xkm, scaling)

Return a new iterate equal to `xkm` except at coordinate `jk`, which becomes
`scaling * v[jk]` clipped to `[0, -scaling]`. `scaling` is expected to be
negative, e.g. an entry of `compute_scaling(A)`. `xkm` itself is not modified.
"""
function update_x(v, jk, xkm, scaling)
    # Copy instead of aliasing: callers keep using xkm as the previous iterate
    # (e.g. update_y combines x and xkm), so writing into it would make them identical.
    x = copy(xkm)
    x[jk] = min(max(scaling * v[jk], 0), -scaling) # n has appeared in v
    return x
end

"""
    update_y(n, ak, Ak, Akm, x, xkm, xtildek, A)

Form the combined iterate
`(Akm/Ak) * xtildek + (n*ak/Ak) * x - ((n-1)*ak/Ak) * xkm`
and return it together with its image under `A`, as `(xtildek_new, A * xtildek_new)`.
The inputs are not modified.
"""
function update_y(n, ak, Ak, Akm, x, xkm, xtildek, A)
    w_prev = Akm / Ak
    w_new = n * ak / Ak
    w_old = (n - 1) * ak / Ak
    combined = @. w_prev * xtildek + w_new * x - w_old * xkm
    return combined, A * combined
end

"""
    update_ak_Ak(Ak, Akm, ak, akm, n)

Advance the step weights by one iteration. The new `ak` is
`min(n*ak/(n-1), sqrt(Ak)/(2n))` and `Ak` grows by it; the previous values shift
into `Akm`/`akm`. The incoming `akm` is ignored.

Returns `(Ak_new, Ak, ak_new, ak)`.
"""
function update_ak_Ak(Ak, Akm, ak, akm, n)
    next_ak = min(n * ak / (n - 1), sqrt(Ak) / (2 * n))
    return Ak + next_ak, Ak, next_ak, ak
end

"""
    update_ybar(y, ykm, ak, akm)

Extrapolate from the two most recent `y` iterates:
`ybar = y + (akm / ak) * (y - ykm)`. Returns a new array and leaves the inputs untouched.
"""
function update_ybar(y, ykm, ak, akm)
    weight = akm / ak
    return @. y + weight * (y - ykm)
end



"""
    remove_col1(A, b)

Keep only the columns `j` of `A` with `A[:, j]' * b > 0`, and divide each kept
column by that inner product.

Returns `(Ascaled, B, s)`:
- `B` holds the kept columns, unscaled;
- `s` is the vector of their positive inner products with `b`;
- `Ascaled = B ./ s'`, so each column of `Ascaled` has inner product 1 with `b`
  (up to rounding).
"""
function remove_col1(A, b)
    proj = vec(A' * b)
    cols = findall(>(0), proj)
    B = A[:, cols]
    s = proj[cols]
    return B ./ s', B, s
end

[32m[1m   Updating[22m[39m registry at `C:\Users\95889\.julia\registries\General`
[32m[1m  Resolving[22m[39m package versions...
[32m[1m  Installed[22m[39m Xorg_libXext_jll ───────────── v1.3.4+4
[32m[1m  Installed[22m[39m FriBidi_jll ────────────────── v1.0.5+6
[32m[1m  Installed[22m[39m XSLT_jll ───────────────────── v1.1.33+4
[32m[1m  Installed[22m[39m Gettext_jll ────────────────── v0.20.1+7
[32m[1m  Installed[22m[39m GeometryBasics ─────────────── v0.4.1
[32m[1m  Installed[22m[39m Colors ─────────────────────── v0.12.8
[32m[1m  Installed[22m[39m Plots ──────────────────────── v1.24.3
[32m[1m  Installed[22m[39m ColorSchemes ───────────────── v3.15.0
[32m[1m  Installed[22m[39m IterTools ──────────────────── v1.4.0
[32m[1m  Installed[22m[39m MacroTools ─────────────────── v0.5.9
[32m[1m  Installed[22m[39m Wayland_jll ────────────────── v1.17.0+4
[32m[1m  Installed[22m[39m GLFW_jll ───────────────────── v3.3.4+0
[32m[1m  Instal

 [90m [0656b61e] [39m[92m+ GLFW_jll v3.3.4+0[39m
 [90m [28b8d3ca] [39m[92m+ GR v0.62.1[39m
 [90m [d2c73de3] [39m[92m+ GR_jll v0.58.1+0[39m
 [90m [5c1252a2] [39m[92m+ GeometryBasics v0.4.1[39m
 [90m [78b55507] [39m[92m+ Gettext_jll v0.20.1+7[39m
 [90m [7746bdde] [39m[92m+ Glib_jll v2.59.0+4[39m
 [90m [42e2da0e] [39m[92m+ Grisu v1.0.2[39m
 [90m [cd3eb016] [39m[92m+ HTTP v0.9.17[39m
 [90m [83e8ac13] [39m[92m+ IniFile v0.5.0[39m
 [90m [c8e1da08] [39m[92m+ IterTools v1.4.0[39m
 [90m [82899510] [39m[92m+ IteratorInterfaceExtensions v1.0.0[39m
 [90m [aacddb02] [39m[92m+ JpegTurbo_jll v2.0.1+3[39m
 [90m [c1c5ebd0] [39m[92m+ LAME_jll v3.100.0+3[39m
 [90m [dd4b983a] [39m[92m+ LZO_jll v2.10.0+3[39m
 [90m [b964fa9f] [39m[92m+ LaTeXStrings v1.3.0[39m
 [90m [23fbe1c1] [39m[92m+ Latexify v0.15.9[39m
 [90m [dd192d2f] [39m[92m+ LibVPX_jll v1.9.0+1[39m
 [90m [e9f186c6] [39m[92m+ Libffi_jll v3.2.1+4[39m
 [90m [d4300ac3] [39m[92m+ 

In [None]:
# Main loop
epsilon = 0.001
n = 200 # input dimension (number of columns of A before filtering)
m = 20 # Number of data

# b can also be random and negative. m<<n.
b = rand(m, 1) .- 0.3
A = rand(m, n)

(A, B, s) = remove_col1(A, b)
# remove_col1 drops every column with A[:, j]' * b <= 0, so the dimension can shrink.
# Everything below must use the filtered column count: keeping n = 200 makes
# `A * xkm` throw a DimensionMismatch and lets jk index past the last column.
n = size(A, 2)

(ybar, ykm, xtildek, ktotal, ak, akm, Ak, Akm, v, xkm, scaling_vector) = init_all(epsilon, m, n, A)

# Objective f(x) = ||A x||^2 / 2 - sum(x), tracked at xtildek; entry 1 is the starting point.
our_result = zeros(ktotal)
our_result[1] = norm(A * xtildek)^2 / 2 - sum(xtildek)

for k in 2:ktotal
        # Declared global so the loop behaves the same in a notebook and in a
        # non-interactive script (where soft-scope assignment would create locals).
        global v, xtildek, ybar, xkm, ykm, Ak, Akm, ak, akm

        # sample jk uniformly from 1:n
        jk = rand(1:n)

        # update v
        v = update_v(jk, A, ybar, v, ak)

        # update x
        x = update_x(v, jk, xkm, scaling_vector[jk])

        # update xtilde and y = A * xtilde
        (xtildek, y) = update_y(n, ak, Ak, Akm, x, xkm, xtildek, A)

        # update the weights a_k, A_k
        (Ak, Akm, ak, akm) = update_ak_Ak(Ak, Akm, ak, akm, n)

        # extrapolate ybar from the two most recent y iterates
        ybar = update_ybar(y, ykm, ak, akm)

        # shift iterates: x_{k-1} <- x_k and y_{k-1} <- y_k
        xkm = x
        ykm = y

        our_result[k] = norm(A * xtildek)^2 / 2 - sum(xtildek)
end

our_result1 = xtildek
# Final objective at the returned point: sum over x, not over the objective history.
println(norm(A * our_result1)^2 / 2 - sum(our_result1))
f = plot(our_result)
@show f