In [1]:
import numba
import numpy as np
import time, math
from numba import njit, jit, cuda, types
from numba import int32, float32, float64
from tqdm import tqdm

In [2]:
@cuda.jit
def gemv(alpha, A, x, y, tran):
    
    Row = cuda.grid(1)
    n = A.shape[0]
    p = A.shape[1]
    pvalue = 0.
    if tran:
        if Row < p:
            for i in range(n):
                pvalue += A[i, Row] * x[i]
            y[Row] += alpha * pvalue
    else:
        if Row < n:
            for j in range(p):
                pvalue += A[Row, j] * x[j]
            y[Row] += alpha * pvalue

In [3]:
@cuda.jit
def soft_thresh(x, alpha, res):
    
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        thresh = math.fabs(x[idx]) - alpha
        res[idx] = math.copysign(thresh, x[idx]) if (thresh > 0) else 0.0

In [4]:
@cuda.jit
def axpy(alpha, x, y):
    
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        y[idx] += alpha * x[idx]

In [5]:
@cuda.jit
def vec_prod(x, y, res):
    
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        res[idx] = x[idx] * y[idx]

In [6]:
@cuda.reduce
def reduce(x, y):
    return x+y

In [7]:
def FISTA(beta, X, y, lamb, L = 10, eta = 1.2, tol = 1e-08, max_iter = 5000, dtype = np.float64):
    
    n = X.shape[0]
    p = X.shape[1]
    
    if dtype == np.float64:
        t = np.float64(1.0)
    else:
        t=  np.float32(1.0)
        
    crit = np.zeros(max_iter, dtype = dtype)
    d_bstar = np.zeros(p, dtype = dtype)
    d_beta = np.zeros(p, dtype = dtype)
    d_diff_beta2 = np.zeros(p, dtype = dtype)
    d_XTrbpd = np.zeros(p, dtype = dtype)
    d_beta_prev = np.zeros(p,dtype = dtype)
    d_beta_p = np.zeros(p, dtype = dtype)
    d_ymXbp2 = np.zeros(n, dtype = dtype)
    d_ymXb2= np.zeros(n, dtype = dtype)
    
    L_prev = L
    
    TPB = (32,1)
    BPG_p = (math.ceil(p / 32) , 1)
    BPG_n = (math.ceil(n / 32) , 1)
    
    
    
    for k in range(max_iter):
        
        ## r(beta')
        d_ymXbp = y.copy()
        gemv[BPG_n, TPB](-1.0, X, d_beta_p, d_ymXbp, False)
        vec_prod[BPG_n, TPB](d_ymXbp, d_ymXbp, d_ymXbp2)
        h_rbp = reduce(d_ymXbp2)
        
        # X^T * r(beta')
        d_XTrbp = np.zeros(p, dtype = dtype)
        gemv[BPG_p, TPB](1.0, X, d_ymXbp, d_XTrbp, True)
        
        i_k = -1
        cond = True
        
        while cond:
            i_k += 1
            L_cur = L_prev * (eta ** i_k)
            
            d_bstar = d_beta_p.copy()
            axpy[BPG_p, TPB]((1.0 / L_cur), d_XTrbp, d_bstar)
            soft_thresh[BPG_p, TPB](d_bstar, lamb / L_cur, d_beta)
            
            ## RHS
            d_diff_beta = d_beta.copy()
            axpy[BPG_p, TPB](-1.0, d_beta_p, d_diff_beta)
            vec_prod[BPG_p, TPB](d_diff_beta, d_diff_beta, d_diff_beta2)
            vec_prod[BPG_p, TPB](d_XTrbp, d_diff_beta, d_XTrbpd)
            
            RHS = L_cur * reduce(d_diff_beta2) - 2.0 * reduce(d_XTrbpd)
            
            ## LHS
            d_ymXb = y.copy()
            gemv[BPG_n, TPB](-1.0, X, d_beta, d_ymXb, False)
            vec_prod[BPG_n, TPB](d_ymXb, d_ymXb, d_ymXb2)
            h_rb = reduce(d_ymXb2)
            
            LHS =  h_rb - h_rbp
            
            cond = (LHS > RHS)
        
        L_prev = L_cur
        tnext = (1.0 + math.sqrt(1 + 4 * t**2) ) / 2.0
        
        d_diff_beta = d_beta.copy()
        axpy[BPG_p, TPB](-1.0, d_beta_prev, d_diff_beta)
        t1 = (t - 1.0) / tnext
        
        d_beta_p = d_beta.copy()
        axpy[BPG_p, TPB](t1, d_diff_beta, d_beta_p)
        
        vec_prod[BPG_p, TPB](d_diff_beta, d_diff_beta, d_diff_beta2)
        crit[k] = np.sqrt(reduce(d_diff_beta2))
        
        if crit[k] < tol:
            break
        t = tnext
        d_beta_prev = d_beta.copy()
    
    return d_beta, crit, k
    

In [13]:
np.random.seed(2022)
n = 100
p = 200

In [14]:
X = np.random.randn(n,p).astype(np.float64)
tr_beta = np.zeros(p).astype(np.float64)
tr_beta[:int(0.05*p)] = 1.0
y = np.dot(X, tr_beta) + np.random.randn(n).astype(np.float64)

loss = np.zeros(5000).astype(np.float64)## max_iter size만큼
lam =  np.sqrt(2 * np.log(p) / n)
L = np.float64(10)
eta = np.float64(2)

In [15]:
beta = np.zeros(p).astype(np.float64)

t1 = time.time()
res = FISTA(beta, X, y, lam, L, eta, tol = 1e-04, max_iter = 5000)
t2 = time.time()
print("Elapsed time: {:4f}".format(t2-t1))

Elapsed time: 10.562849


In [17]:
res[0].dtype

dtype('float64')

In [18]:
res[0][:20]

array([ 0.64070707,  1.11958681,  0.82775298,  1.11552658,  0.69531896,
        0.82641614,  0.73400349,  1.05100404,  0.91740151,  1.10698033,
        0.        , -0.26072158,  0.        ,  0.        ,  0.        ,
        0.1230212 , -0.16857445,  0.        , -0.09391906,  0.        ])

In [14]:
np.random.seed(2022)
n = 100; p = 200
xmat = np.random.normal(size = (n, p)).astype("float32")
beta = np.zeros(p, dtype = "float32")
beta[:int(0.05*p)] = 1.0
y = np.dot(xmat, beta) + np.random.normal(size = n).astype("float32")
beta_sol = np.zeros(p, dtype = "float32")

lamb = np.sqrt(2 * np.log(p) / n) 
L = np.float32(10)
eta = np.float32(2)
tol = np.float32(1e-08)

In [15]:
res = FISTA(beta_sol, xmat, y, lamb, L, eta, tol = 1e-08, max_iter = 5000, dtype = np.float32)

In [17]:
res[0].dtype

dtype('float32')

In [1]:
import numpy as np
dtype = np.float32
dtype(1.0)

1.0

In [3]:
n = 100
p = 200

import numpy as np
from numba import cuda
import math
import time


@cuda.jit
def gemv(alpha, A, x, y, tran):
    Row = cuda.grid(1)
    n = A.shape[0]
    p = A.shape[1]
    pvalue = 0.
    if tran:
        if Row < p:
            for i in range(n):
                pvalue += A[i, Row] * x[i]
            y[Row] += alpha * pvalue
    else:
        if Row < n:
            for j in range(p):
                pvalue += A[Row, j] * x[j]
            y[Row] += alpha * pvalue

@cuda.jit
def soft_thr(x, alpha, S):
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        thr = math.fabs(x[idx]) - alpha
        if thr > 0:
            S[idx] = math.copysign(thr, x[idx])
        else:
            S[idx] = 0

@cuda.jit
def axpy(alpha, x, y):
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        y[idx] += alpha * x[idx]

        
@cuda.jit
def vec_prod(x, y, res):
    idx = cuda.grid(1)
    if idx < x.shape[0]:
        res[idx] = x[idx] * y[idx]

@cuda.reduce
def reduce(x, y):
    return x+y

def FISTA(beta, X, y, lamb, L = 10, eta = 1.2, tol = 1e-08, 
          max_iter = 5000, dtype = np.float32):
    n = X.shape[0]
    p = X.shape[1]
    t = dtype(1.0)
    crit = np.zeros(max_iter, dtype = dtype)
    d_bstar = np.zeros(p, dtype = dtype)
    d_beta = np.zeros(p, dtype = dtype)
    d_diff_beta2 = np.zeros(p, dtype = dtype)
    d_XTrbpd = np.zeros(p, dtype = dtype)
    d_beta_prev = np.zeros(p,dtype = dtype)
    d_beta_p = np.zeros(p, dtype = dtype)
    d_ymXbp2 = np.zeros(n, dtype = dtype)
    d_ymXb2= np.zeros(n, dtype = dtype)
    L_prev = L
    TPB = (32,1)
    BPG_p = (math.ceil(p / 32) , 1)
    BPG_n = (math.ceil(n / 32) , 1)
    
    for k in range(max_iter):
        d_ymXbp = y.copy()
        gemv[BPG_n, TPB](-1.0, X, d_beta_p, d_ymXbp, False)
        vec_prod[BPG_n, TPB](d_ymXbp, d_ymXbp, d_ymXbp2)
        h_rbp = reduce(d_ymXbp2)
        d_XTrbp = np.zeros(p, dtype = dtype)
        gemv[BPG_p, TPB](1.0, X, d_ymXbp, d_XTrbp, True)
        
        i_k = -1
        cond = True
        while cond:
            i_k += 1
            L_cur = L_prev * (eta ** i_k)
            d_bstar = d_beta_p.copy()
            axpy[BPG_p, TPB]((1.0 / L_cur), d_XTrbp, d_bstar)
            soft_thr[BPG_p, TPB](d_bstar, lamb / L_cur, d_beta)
            d_diff_beta = d_beta.copy()
            axpy[BPG_p, TPB](-1.0, d_beta_p, d_diff_beta)
            vec_prod[BPG_p, TPB](d_diff_beta, d_diff_beta, d_diff_beta2)
            vec_prod[BPG_p, TPB](d_XTrbp, d_diff_beta, d_XTrbpd)
            RHS = L_cur * reduce(d_diff_beta2) - 2.0 * reduce(d_XTrbpd)
            d_ymXb = y.copy()
            gemv[BPG_n, TPB](-1.0, X, d_beta, d_ymXb, False)
            vec_prod[BPG_n, TPB](d_ymXb, d_ymXb, d_ymXb2)
            h_rb = reduce(d_ymXb2)
            LHS =  h_rb - h_rbp
            cond = (LHS > RHS)
        
        L_prev = L_cur
        tnext = (1.0 + np.sqrt(1 + 4 * t**2) ) / 2.0
        d_diff_beta = d_beta.copy()
        axpy[BPG_p, TPB](-1.0, d_beta_prev, d_diff_beta)
        t1 = (t - 1.0) / tnext
        d_beta_p = d_beta.copy()
        axpy[BPG_p, TPB](t1, d_diff_beta, d_beta_p)
        vec_prod[BPG_p, TPB](d_diff_beta, d_diff_beta, d_diff_beta2)
        crit[k] = np.sqrt(reduce(d_diff_beta2))
        if crit[k] < tol:
            break
        t = tnext
        d_beta_prev = d_beta.copy()
    return d_beta, crit, k

np.random.seed(2022)
X = np.random.randn(n,p).astype(np.float64)
tr_beta = np.zeros(p).astype(np.float64)
tr_beta[:int(0.05*p)] = 1.0
y = np.dot(X, tr_beta) + np.random.randn(n).astype(np.float64)

niter = 11
comp_time_numba = np.zeros(niter)

for i in np.arange(niter):

    beta = np.zeros(p, dtype = np.float64)
    lam = np.sqrt(2*np.log(p)/n).astype(np.float64)
    L = np.float64(10)
    eta = np.float64(2)
    
    t1 = time.time()
    out = FISTA(beta, X, y, lam, L, eta, tol = 1e-04, max_iter = 5000, dtype=np.float64)
    t2 = time.time()
    comp_time_numba[i] = t2-t1
    
    print('\n ',i+1,'-th iteration, Collapsed Time: ', comp_time_numba[i])
    



  1 -th iteration, Collapsed Time:  22.553854942321777


In [4]:
out

(array([ 6.31984770e-01,  1.11394525e+00,  8.37214291e-01,  1.12377369e+00,
         6.83480203e-01,  8.27868700e-01,  7.39624679e-01,  1.05381560e+00,
         9.09702361e-01,  1.10813928e+00,  0.00000000e+00, -2.69272387e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  1.32332742e-01,
        -1.59032449e-01,  0.00000000e+00, -9.83042866e-02,  0.00000000e+00,
         0.00000000e+00, -1.24380991e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  9.49611440e-02,
        -4.08564329e-01,  9.02729854e-02, -1.68434799e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.84062183e-01,  0.00000000e+00,
         1.75486445e-01,  0.00000000e+00,  0.00000000e+00,  5.76819442e-02,
         3.21197324e-02,  1.68470517e-02,  0.00000000e+00, -5.49789406e-02,
         0.00000000e+00,  1.42752349e-01,  0.00000000e+00,  0.00000000e+00,
         5.80113344e-02, -1.27389178e-01, -4.92529050e-02,  0.00000000e+00,
         5.1