In this notebook we implement the lasso in PyTorch using the iterative shrinkage-thresholding algorithms ISTA and FISTA:

In [1]:
import torch
from DeePyMoD_SBL.data import Burgers
import numpy as np

In [2]:
# Sample points along the spatial (x) and temporal (t) axes.
x = np.linspace(start=-2, stop=5, num=75)
t = np.linspace(start=0.5, stop=5.0, num=25)

# 'ij' indexing puts x along axis 0 and t along axis 1: both grids are (75, 25).
x_grid, t_grid = np.meshgrid(x, t, indexing="ij")

In [3]:
# Burgers dataset from DeePyMoD_SBL, queried below for the library and the
# time derivative.
# NOTE(review): the meaning of the two positional arguments (0.1, 1.0) is not
# visible here -- confirm against the Burgers signature.
dataset = Burgers(0.1, 1.0)

In [4]:
# Flatten both grids into column vectors: one row per (x, t) sample point.
x_col = x_grid.reshape(-1, 1)
t_col = t_grid.reshape(-1, 1)

# Evaluate the dataset's library and its time derivative on those points.
library = dataset.library(x_col, t_col)
time_deriv = dataset.time_deriv(x_col, t_col)

In [5]:
# Scale every library column, and the time derivative as a whole, to unit norm
# before converting to float32 tensors.
column_norms = np.linalg.norm(library, axis=0, keepdims=True)
target_norm = np.linalg.norm(time_deriv)

X = torch.tensor(library / column_norms, dtype=torch.float32)
y = torch.tensor(time_deriv / target_norm, dtype=torch.float32)

In [6]:
# Ordinary least-squares reference solution via the normal equations,
# (X^T X)^-1 X^T y, to compare the lasso results against.
gram = X.T @ X
torch.inverse(gram) @ (X.T @ y)

tensor([[ 5.9456e-06],
        [-1.9073e-06],
        [ 7.9301e-01],
        [-6.1989e-05],
        [-2.6554e-05],
        [-8.3787e-01],
        [ 1.6403e-04],
        [ 8.3923e-05],
        [ 4.4227e-05],
        [-3.8147e-06],
        [-1.0300e-04],
        [-4.5776e-05]])

In [7]:
# Machine epsilon of float32, the dtype used for X and y.
torch.finfo(torch.float32).eps

1.1920928955078125e-07

In [8]:
# Machine epsilon of float64, for comparison with the float32 value.
torch.finfo(torch.float64).eps

2.220446049250313e-16

# FISTA according to Rudy

In [46]:
# Step-size constant for the gradient step (the update below divides by L).
# torch.norm with default arguments is the Frobenius norm, which is an upper
# bound on the largest eigenvalue of X^T X, so 1 / L is a conservative step.
L = torch.norm(X.T @ X) # Note: L is different from Rudy

In [99]:
# Solver settings for the FISTA run.
max_its = 100_000   # intended cap on the number of iterations
l1 = 1e-5           # weight of the L1 penalty
threshold = 1e-8    # stop once the largest coefficient change drops below this

# Current and previous iterate both start at zero, one row per column of X.
n_terms = X.shape[1]
w = torch.zeros((n_terms, 1))
w_old = torch.zeros_like(w)

In [100]:
# FISTA: momentum extrapolation, a gradient step of size 1 / L on
# 0.5 * ||X w - y||^2, then soft-thresholding at l1 / L.
#
# The iteration counter `it` is owned by the loop, so the cell no longer
# depends on a leftover `it` in the kernel, and `max_its` bounds the run so a
# tolerance that is never reached cannot hang the notebook.
converged = False
for it in range(max_its):
    # Momentum weight it / (it + 1): 0 on the first pass, then approaching 1.
    z = w + it / (it + 1) * (w - w_old)
    w_old = w
    # Gradient step evaluated at the extrapolated point z.
    z = z - X.T @ (X @ z - y) / L
    # Soft-thresholding (proximal operator of the L1 penalty).
    w = torch.sign(z) * torch.max(torch.abs(z) - l1 / L, torch.zeros_like(z))
    # Converged once no coefficient moved by more than `threshold`.
    converged = (torch.max(torch.abs(w - w_old)) < threshold).item()
    if converged:
        break
else:
    print(f"FISTA stopped after {max_its} iterations without reaching the threshold")

In [101]:
# Coefficients found by the FISTA loop above.
w

tensor([[-0.0000e+00],
        [-1.2646e-05],
        [ 7.9304e-01],
        [ 0.0000e+00],
        [-0.0000e+00],
        [-8.3782e-01],
        [ 0.0000e+00],
        [ 3.1286e-06],
        [-0.0000e+00],
        [-0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00]])

In [98]:
# NOTE(review): this cell's execution count (98) precedes the setup and loop
# cells above (99, 100), so its output comes from an earlier run -- presumably
# with different settings. Re-run or remove.
w

tensor([[-0.0000e+00],
        [-2.3510e-05],
        [ 7.9304e-01],
        [ 0.0000e+00],
        [-0.0000e+00],
        [-8.3780e-01],
        [ 0.0000e+00],
        [ 5.0938e-06],
        [-0.0000e+00],
        [-1.1717e-05],
        [ 0.0000e+00],
        [ 0.0000e+00]])

In [85]:
# Re-evaluate the stopping criterion: did the last update move every
# coefficient by less than `threshold`?
((w - w_old).abs().max() < threshold).item()

True

# ISTA

In [334]:
# ISTA settings: L1 penalty weight and the tolerance on the largest
# coefficient change used as stopping criterion.
l1, threshold = 1e-5, 1e-7

In [335]:
# Lipschitz constant of the gradient of ||X w - y||^2: twice the largest
# eigenvalue of X^T X.
# torch.symeig is deprecated and has been removed from current PyTorch;
# torch.linalg.eigvalsh returns the same eigenvalues of the symmetric matrix
# (UPLO="U" mirrors symeig's default upper=True).
L = 2 * torch.linalg.eigvalsh(X.T @ X, UPLO="U").max()

In [336]:
def S(alpha, x):
    """Soft-thresholding operator: shrink every entry of `x` towards zero by `alpha`.

    Entries whose magnitude is at most `alpha` become zero; the others keep
    their sign and lose `alpha` in magnitude.

    Args:
        alpha: shrinkage amount (number or tensor broadcastable against `x`).
        x: tensor to shrink.

    Returns:
        Tensor with sign(x) * max(|x| - alpha, 0) applied elementwise.
    """
    shrunk_magnitude = torch.max(x.abs() - alpha, torch.zeros_like(x))
    return shrunk_magnitude * x.sign()

In [337]:
# Start ISTA from the zero vector: one coefficient per column of X.
w = torch.zeros((X.shape[1], 1), dtype=torch.float32)

In [338]:
# ISTA: gradient step on ||X w - y||^2 with step size 1 / L, followed by
# soft-thresholding at l1 / L, repeated until the largest coefficient change
# falls below `threshold`.
#
# The loop is capped: with a tolerance close to float32 precision the
# criterion may never be met, and an unbounded `while` would hang the notebook.
max_its = 100000
converged = False
for _ in range(max_its):
    w_old = w
    w = S(l1/L, w_old - 2/ L * X.T @ (X @ w_old - y))
    converged = (torch.max(torch.abs(w - w_old)) < threshold).item()
    if converged:
        break
else:
    print(f"ISTA stopped after {max_its} iterations without reaching the threshold")

In [339]:
# Coefficients found by the ISTA loop above.
w

tensor([[-0.0000e+00],
        [-1.4649e-05],
        [ 7.9305e-01],
        [ 0.0000e+00],
        [-0.0000e+00],
        [-8.3782e-01],
        [ 0.0000e+00],
        [ 2.8472e-06],
        [-7.4072e-09],
        [-1.0197e-05],
        [ 0.0000e+00],
        [ 0.0000e+00]])

So that seems to work :)

# FISTA

In [394]:
# Twice the largest eigenvalue of X^T X.
# torch.symeig is deprecated and has been removed from current PyTorch;
# torch.linalg.eigvalsh returns the same eigenvalues of the symmetric matrix
# (UPLO="U" mirrors symeig's default upper=True).
L = 2 * torch.linalg.eigvalsh(X.T @ X, UPLO="U").max()

In [395]:
def S(alpha, x):
    """Soft-threshold `x` by `alpha`: sign(x) * max(|x| - alpha, 0), elementwise.

    Same formula as the `S` defined in the ISTA section; repeated here so this
    section can be run on its own.
    """
    zeros = torch.zeros_like(x)
    magnitude = torch.abs(x) - alpha
    return torch.sign(x) * torch.max(magnitude, zeros)

In [426]:
# FISTA settings, stored as 0-dim tensors.
float32_eps = torch.finfo(torch.float32).eps  # machine precision
l1 = torch.tensor(1e-5)
threshold = torch.tensor(float32_eps)

In [427]:
# Initial FISTA state: zero coefficients, one per column of X.
w = torch.zeros((X.shape[1], 1), dtype=torch.float32)
# z is bound to the same tensor object as w (no copy is made).
z = w

# NOTE(review): this rebinds `t`, which earlier in the notebook held the time
# grid from np.linspace; presumably it is the FISTA momentum scalar here --
# the code that consumes it is not visible, confirm.
t = torch.tensor(1.0)

In [12]:
# NOTE(review): `lasso` is not defined or imported anywhere in the visible
# notebook, and the execution count (12) suggests this cell was run in a
# different kernel session -- on a fresh Restart & Run All this raises
# NameError unless the definition is restored.
# The third argument is presumably the L1 penalty weight -- confirm.
lasso(X, y, torch.tensor(1e-5))

tensor([[-0.0000e+00],
        [-8.4144e-06],
        [ 7.9305e-01],
        [ 0.0000e+00],
        [-0.0000e+00],
        [-8.3783e-01],
        [ 0.0000e+00],
        [ 2.4729e-06],
        [-0.0000e+00],
        [-0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00]])