In [1]:
import torch
import numpy as np
import time

In [2]:
print(torch.__version__)

1.7.1


In [3]:
a = torch.tensor((1, 2, -1))
b = torch.tensor((3, 0, 4))
torch.maximum(a, b)

tensor([3, 2, 4])

In [7]:
def soft_thr(x, alpha):
    n = x.shape[0]
    S = torch.maximum(x-alpha, 
        torch.zeros(n,device='cuda:0'))*torch.sign(x)
    return S

In [8]:
def FISTA(beta, X, y, lamb, L, eta , tol = 1e-04, max_iter = 5000, dtype = torch.float32):
    if(dtype == torch.float32):
        torch.set_default_tensor_type(torch.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.DoubleTensor)
    device = torch.device("cuda:0")
    n = X.shape[0]
    p = X.shape[1]
    dbeta = torch.Tensor(beta).to(device)
    dX = torch.Tensor(X).to(device)
    dy = torch.Tensor(y).to(device)
    t = torch.ones(1,dtype=dtype, device=device)
    crit = np.zeros(max_iter)
    dbeta_p = torch.Tensor(beta).to(device)
    dbeta_prev = torch.Tensor(beta).to(device)
    L_prev = L
    for k in range(max_iter):
        dymXbp = dy-torch.matmul(dX,dbeta_p)
        drbp = torch.dot(dymXbp, dymXbp)
        dXTrbp = torch.matmul(torch.t(dX),dymXbp)

        i_k = -1
        cond = True
        while cond:
            i_k += 1
            L_cur = L_prev*(eta**i_k)
            dbstar = dbeta_p + dXTrbp/L_cur
            dbeta = soft_thr(dbstar, lamb/L_cur)
            diff_beta = dbeta - dbeta_p
            RHS_1st = torch.dot(diff_beta,diff_beta)
            RHS_2nd = torch.dot(diff_beta,dXTrbp)
            RHS = L_cur*RHS_1st-2.0*RHS_2nd
            dymXb = dy-torch.matmul(dX,dbeta)
            LHS = torch.dot(dymXb,dymXb)-drbp
            cond = (LHS>RHS)
        L_prev = L_cur
        tnext = (1.0+torch.sqrt(1+4*t**2))/2.0
        diff_beta = dbeta-dbeta_prev
        t1 = (t-1.0)/tnext
        dbeta_p = dbeta+t1*diff_beta
        crit[k] = torch.norm(diff_beta)
        if crit[k] < tol:
            break
        t = tnext
        dbeta_prev = dbeta
    out = dbeta.to('cpu')
    return out.numpy(), crit, k

In [9]:
niter = 1
comp_time_torch = np.zeros((niter,2))

n = 100; p = 200
for i in range(niter):

    xmat = np.random.normal(size = (n, p)).astype("float64")
    beta = np.zeros(p)
    beta[:int(0.05*p)] = 1.0
    y = np.dot(xmat, beta) + np.random.normal(size = n)
    lamb = np.float64(np.sqrt(2 * np.log(p) / n) )
    L = np.float64(10)
    eta = np.float64(2)
    tol = np.float64(1e-04)

    beta_sol = np.zeros(p, dtype = "float64")

    t1 = time.time()
    res = FISTA(beta_sol, xmat, y, lamb, L, eta, tol = tol, max_iter = 5000,dtype=torch.float64)
    t2 = time.time()

    comp_time_torch[i,0] = t2 - t1


In [1]:
import torch
import numpy as np
import time

def soft_thr(x, alpha):
    n = x.shape[0]
    S = torch.maximum(x-alpha, 
        torch.zeros(n,device='cuda:0'))*torch.sign(x)
    return S

def FISTA(beta, X, y, lamb, L, eta , tol = 1e-04, max_iter = 5000, dtype = torch.float32):
    if(dtype == torch.float32):
        torch.set_default_tensor_type(torch.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.DoubleTensor)
    device = torch.device("cuda:0")
    n = X.shape[0]
    p = X.shape[1]
    dbeta = torch.Tensor(beta).to(device)
    dX = torch.Tensor(X).to(device)
    dy = torch.Tensor(y).to(device)
    t = torch.ones(1,dtype=dtype, device=device)
    crit = np.zeros(max_iter)
    dbeta_p = torch.Tensor(beta).to(device)
    dbeta_prev = torch.Tensor(beta).to(device)
    L_prev = L
    for k in range(max_iter):
        dymXbp = dy-torch.matmul(dX,dbeta_p)
        drbp = torch.dot(dymXbp, dymXbp)
        dXTrbp = torch.matmul(torch.t(dX),dymXbp)

        i_k = -1
        cond = True
        while cond:
            i_k += 1
            L_cur = L_prev*(eta**i_k)
            dbstar = dbeta_p + dXTrbp/L_cur
            dbeta = soft_thr(dbstar, lamb/L_cur)
            diff_beta = dbeta - dbeta_p
            RHS_1st = torch.dot(diff_beta,diff_beta)
            RHS_2nd = torch.dot(diff_beta,dXTrbp)
            RHS = L_cur*RHS_1st-2.0*RHS_2nd
            dymXb = dy-torch.matmul(dX,dbeta)
            LHS = torch.dot(dymXb,dymXb)-drbp
            cond = (LHS>RHS)
        L_prev = L_cur
        tnext = (1.0+torch.sqrt(1+4*t**2))/2.0
        diff_beta = dbeta-dbeta_prev
        t1 = (t-1.0)/tnext
        dbeta_p = dbeta+t1*diff_beta
        crit[k] = torch.norm(diff_beta)
        if crit[k] < tol:
            break
        t = tnext
        dbeta_prev = dbeta
    out = dbeta.to('cpu')
    return out.numpy(), crit, k

n = 100
p = 200

np.random.seed(2022)
X = np.random.randn(n,p).astype(np.float64)
tr_beta = np.zeros(p).astype(np.float64)
tr_beta[:int(0.05*p)] = 1.0
y = np.dot(X, tr_beta) + np.random.randn(n).astype(np.float64)

niter = 11
comp_time_torch = np.zeros(niter)

for i in np.arange(niter):

    beta = np.zeros(p, dtype = np.float64)
    lam = np.sqrt(2*np.log(p)/n).astype(np.float64)
    L = np.float64(10)
    eta = np.float64(2)
    
    t1 = time.time()
    out = FISTA(beta, X, y, lam, L, eta, tol = 1e-04, max_iter = 5000, dtype=torch.float64)
    t2 = time.time()
    comp_time_torch[i] = t2-t1
    
    print('\n ',i+1,'-th iteration, Collapsed Time: ', comp_time_torch[i])
   




  1 -th iteration, Collapsed Time:  2.1115055084228516

  2 -th iteration, Collapsed Time:  0.23238754272460938

  3 -th iteration, Collapsed Time:  0.22502660751342773

  4 -th iteration, Collapsed Time:  0.22482514381408691

  5 -th iteration, Collapsed Time:  0.26065564155578613

  6 -th iteration, Collapsed Time:  0.27126145362854004

  7 -th iteration, Collapsed Time:  0.22214579582214355

  8 -th iteration, Collapsed Time:  0.2237236499786377

  9 -th iteration, Collapsed Time:  0.22644615173339844

  10 -th iteration, Collapsed Time:  0.2279040813446045

  11 -th iteration, Collapsed Time:  0.22399330139160156
