In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import time

In [2]:
class lasso_nn(nn.Module):
    """Bias-free linear model used as the predictor in a lasso-style fit.

    The module owns a single ``nn.Linear(p, 1, bias=False)`` layer, so its only
    trainable parameter is one weight matrix of shape ``(1, p)``.
    """

    def __init__(self, p):
        """Create the linear layer.

        Args:
            p: number of input features (columns of the design matrix).
        """
        super().__init__()
        self.linear = nn.Linear(in_features=p, out_features=1, bias=False)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result unchanged."""
        return self.linear(x)

    def l1_loss(self, w):
        """Return the sum of absolute values of the tensor ``w`` (its L1 norm)."""
        return torch.sum(torch.abs(w))

In [3]:
# Simulate a sparse linear regression problem with more features than samples.
np.random.seed(2022)  # fixes the NumPy draws below
n, p = 100, 200

# Design matrix: n x p standard-normal entries.
xmat = np.random.normal(size=(n, p)).astype(np.float64)

# True coefficients: the first 5% of the entries are 1.0, the rest are 0.0.
beta = (np.arange(p) < int(0.05 * p)).astype(np.float64)

# Response: linear signal plus standard-normal noise (drawn after xmat).
y = xmat.dot(beta) + np.random.normal(size=n)

# Penalty level sqrt(2 * log(p) / n).
# NOTE(review): no later cell in this notebook reads `lamb`; the training loop
# uses a literal penalty weight instead — confirm which value is intended.
lamb = np.sqrt(2 * np.log(p) / n)

In [4]:
# Seed torch so the random initial weights of the linear layer are the same on
# every Restart & Run All (the data cell only seeds NumPy).
torch.manual_seed(2022)

# float32 copies of the NumPy data; torch.tensor with an explicit dtype replaces
# the legacy torch.Tensor(...) constructor and produces the same values.
x_torch = torch.tensor(xmat, dtype=torch.float32)
y_torch = torch.tensor(y, dtype=torch.float32)

# Use the first GPU when one is present, otherwise run on the CPU instead of
# failing on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
x_dev = x_torch.to(device)
y_dev = y_torch.to(device)

# One weight per column of the design matrix, moved to the same device as the data.
network = lasso_nn(p=xmat.shape[1])
network = network.to(device)

# Data-fit term and a plain SGD optimizer over the model's parameters.
mse_loss = nn.MSELoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

In [5]:
# Full-batch gradient descent on  MSE(pred, y) + 0.05 * ||w||_1.
# NOTE(review): the penalty weight is the literal 0.05, not the `lamb` value
# computed in the data cell — confirm which one is intended.
for epoch in range(10000):
    optimizer.zero_grad()

    # The model and x_dev already live on the same device, so the prediction
    # needs no further transfer.
    pred_y = network(x_dev)

    # view(-1) flattens the (n, 1) prediction so it lines up with the 1-D y_dev.
    loss = mse_loss(pred_y.view(-1), y_dev)

    # Flatten every trainable parameter and penalise their concatenation.
    l1_param = [parameter.view(-1) for parameter in network.parameters()]
    l1 = 0.05 * network.l1_loss(torch.cat(l1_param))

    # Out-of-place add rather than mutating a tensor tracked by autograd.
    loss = loss + l1

    loss.backward()
    optimizer.step()

    if epoch % 1000 == 999:
        # .item() copies the scalar back to the host, so only do it on the
        # epochs that are actually reported.
        c_loss = loss.item()
        print("%d th loss is %.4f" % (epoch+1, c_loss))

1000 th loss is 0.8472
2000 th loss is 0.8099
3000 th loss is 0.8079
4000 th loss is 0.8077
5000 th loss is 0.8077
6000 th loss is 0.8076
7000 th loss is 0.8077
8000 th loss is 0.8076
9000 th loss is 0.8076
10000 th loss is 0.8076


In [6]:
# Materialise the model's parameters into a list.
param_list = list(network.parameters())

# Show the first 20 entries of row 0 of the first parameter tensor.
param_list[0][0][:20]

tensor([ 6.3942e-01,  1.0646e+00,  8.0397e-01,  1.0232e+00,  7.0936e-01,
         7.8687e-01,  7.2832e-01,  1.0274e+00,  9.7006e-01,  1.0989e+00,
         3.8187e-05, -2.5851e-01, -2.1829e-04,  6.7066e-06, -9.9120e-05,
         3.6256e-05, -1.6807e-01,  1.0474e-04, -5.1309e-02,  1.7564e-06],
       device='cuda:0', grad_fn=<SliceBackward>)