In [1]:
import torch
import numpy as np
import time

In [2]:
def soft_thresh(x, alpha):
    device = torch.device("cuda:0")
    x = x.to(device)
    n = x.shape[0]
    res = torch.maximum(x - alpha, torch.zeros(n).to(device)) * torch.sign(x)
    
    return res

In [3]:
def FISTA(beta, X, y, lamb, L, eta , tol = 1e-08, max_iter = 5000, dtype = torch.float64):
    
    if(dtype == torch.float64):
        torch.set_default_tensor_type(torch.DoubleTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    
    device = torch.device("cuda:0")
    
    n = X.shape[0]
    p = X.shape[1]
    
    
    beta = torch.Tensor(beta).to(device)
    X = torch.Tensor(X).to(device)
    y = torch.Tensor(y).to(device)
    
    XT = X.transpose(0,1)
    XT = XT.to(device)
    
    
    t = torch.ones(1, dtype = dtype).to(device)
    crit = np.zeros(max_iter)
    beta_p = torch.zeros(p, dtype =dtype, device = device)
    beta_prev = torch.zeros(p, dtype =dtype, device = device)
    
    L_prev = L
    
    for k in range(max_iter):
        ymXbp = y - X@beta_p
        rbp = torch.dot(ymXbp, ymXbp)
        
        XTrbp = XT@ymXbp
        
        i_k = -1
        cond = True
        while cond:
            i_k += 1
            L_cur = L_prev * (eta ** i_k)
            
            bstar = beta_p + XTrbp / L_cur
            d_beta = soft_thresh(bstar, lamb / L_cur)
            
            #RHS
            diff_beta = d_beta - beta_p
            RHS_1st = torch.dot(diff_beta, diff_beta)
            RHS_2nd = torch.dot(diff_beta, XTrbp)
            
            RHS = L_cur * RHS_1st - 2.0 * RHS_2nd
            
            #LHS
            ymXb = y - X@d_beta
            LHS = torch.dot(ymXb, ymXb)  - rbp
            
            cond = (LHS > RHS)
        
        L_prev = L_cur
        tnext =  (1.0 + torch.sqrt(1 + 4 * t**2)) / 2.0
        diff_beta = d_beta - beta_prev
        
        t1 = (t - 1.0 ) / tnext
        beta_p = d_beta + t1 * diff_beta
        
        crit[k] = torch.norm(diff_beta)
        
        if crit[k] < tol:
            break
        
        t = tnext
        beta_prev = d_beta
    
    res = d_beta.to('cpu')
    
    
    return res.numpy(), crit, k

## Double precision test

In [4]:
np.random.seed(2022)
n = 100; p = 200
xmat = np.random.normal(size = (n, p)).astype("float64")
beta = np.zeros(p)
beta[:int(0.05*p)] = 1.0
y = np.dot(xmat, beta) + np.random.normal(size = n)
beta_sol = np.zeros(p, dtype = "float64")

lamb = np.float64(np.sqrt(2 * np.log(p) / n) )
L = np.float64(10)
eta = np.float64(2)
tol = np.float64(1e-04)

In [9]:
niter = 11
comp_time_torch = np.zeros((niter,2))

n = 500; p = 1000
for i in range(niter):
    
    xmat = np.random.normal(size = (n, p)).astype("float64")
    beta = np.zeros(p)
    beta[:int(0.05*p)] = 1.0
    y = np.dot(xmat, beta) + np.random.normal(size = n)
    lamb = np.float64(np.sqrt(2 * np.log(p) / n) )
    L = np.float64(10)
    eta = np.float64(2)
    tol = np.float64(1e-04)
    
    beta_sol = np.zeros(p, dtype = "float64")
    
    t1 = time.time()
    res = FISTA(beta_sol, xmat, y, lamb, L, eta, tol = tol, max_iter = 5000)
    t2 = time.time()
    
    comp_time_torch[i,0] = t2 - t1
    
    xmat = np.random.normal(size = (n, p)).astype("float32")
    beta = np.zeros(p, dtype = np.float32)
    beta[:int(0.05*p)] = 1.0
    y = np.dot(xmat, beta) + np.random.normal(size = n).astype(np.float32)
    beta_sol = np.zeros(p, dtype = "float32")
    
    lamb = np.float32(np.sqrt(2 * np.log(p) / n) )
    L = np.float32(10)
    eta = np.float32(2)
    tol = np.float32(1e-04)
    
    t1 = time.time()
    res = FISTA(beta_sol, xmat, y, lamb, L, eta, tol = tol, max_iter = 5000, dtype = torch.float32)
    t2 = time.time()
    
    comp_time_torch[i, 1] = t2 - t1
    
    print("\n", i+1, "-th iteration, Elapsed time(double): ", comp_time_torch[i,0],
         ', Elapsed time(single): ', comp_time_torch[i,1])
    
    
    
    


 1 -th iteration, Elapsed time(double):  0.4206995964050293 , Elapsed time(single):  0.47515153884887695

 2 -th iteration, Elapsed time(double):  0.4394409656524658 , Elapsed time(single):  0.4901621341705322

 3 -th iteration, Elapsed time(double):  0.4497029781341553 , Elapsed time(single):  0.5435874462127686

 4 -th iteration, Elapsed time(double):  0.5258321762084961 , Elapsed time(single):  0.5039727687835693

 5 -th iteration, Elapsed time(double):  0.49323415756225586 , Elapsed time(single):  0.46979451179504395

 6 -th iteration, Elapsed time(double):  0.4060664176940918 , Elapsed time(single):  0.4055190086364746

 7 -th iteration, Elapsed time(double):  0.5602684020996094 , Elapsed time(single):  0.5138263702392578

 8 -th iteration, Elapsed time(double):  0.4467203617095947 , Elapsed time(single):  0.4526200294494629

 9 -th iteration, Elapsed time(double):  0.5377271175384521 , Elapsed time(single):  0.44888997077941895

 10 -th iteration, Elapsed time(double):  0.448929

In [6]:
res[0][:20]

array([1.0846289 , 1.1303933 , 0.9725525 , 0.81222254, 0.84910697,
       0.8344582 , 0.9020498 , 1.1550989 , 0.6405729 , 0.83180326,
       1.1096005 , 1.0489771 , 0.88255185, 1.273143  , 1.2337017 ,
       1.0239142 , 1.1536824 , 1.2027469 , 0.90862584, 0.93551725],
      dtype=float32)

## Single precision test

In [17]:
np.random.seed(2022)
n = 100; p = 200
xmat = np.random.normal(size = (n, p)).astype("float32")
beta = np.zeros(p, dtype = np.float32)
beta[:int(0.05*p)] = 1.0
y = np.dot(xmat, beta) + np.random.normal(size = n).astype(np.float32)
beta_sol = np.zeros(p, dtype = "float32")

lamb = np.float32(np.sqrt(2 * np.log(p) / n) )
L = np.float32(10)
eta = np.float32(2)
tol = np.float32(1e-04)

In [18]:
t1 = time.time()
res = FISTA(beta_sol, xmat, y, lamb, L, eta, tol = tol, max_iter = 5000, dtype = torch.float32)
t2 = time.time()
print("Elapsed time: {:4f}".format(t2-t1))

Elapsed time: 0.164206


In [19]:
res[0][:20]

array([1.068292  , 0.9096012 , 1.2106907 , 1.124822  , 1.1573865 ,
       1.0303564 , 0.8849843 , 0.8242552 , 1.5445442 , 1.0892318 ,
       0.        , 0.        , 0.33762506, 0.        , 0.07564396,
       0.5808942 , 0.        , 0.        , 0.        , 0.        ],
      dtype=float32)