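Lasso Regression

Fit an L1-regularised (Lasso) linear model to the scikit-learn diabetes dataset using standardised features and an 80/20 train/test split. Report test-set MSE and R², then compare the Lasso coefficients with those of a Ridge model.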

In [2]:
import numpy as np 
import pandas as pd
from sklearn.datasets import load_diabetes

In [3]:
# Load the bundled diabetes dataset as a Bunch; later cells read its
# .data, .target and .feature_names attributes.
data = load_diabetes(return_X_y=False)

In [4]:
# Wrap the raw arrays so the features keep their names and the target is labelled.
y = pd.Series(data=data.target, name='Disease Progression')
x = pd.DataFrame(data=data.data, columns=data.feature_names)

In [5]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42
)

# Sanity check: every row lands in exactly one of the two splits.
assert len(x_train) + len(x_test) == len(x)

In [6]:
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler

# L1 penalty strength. Larger values generally push more coefficients to exactly zero.
# NOTE: at 0.1 every coefficient stayed non-zero in the recorded output, so the
# sparsity filter in the next cell removes nothing; raise this to see feature selection.
LASSO_ALPHA = 0.1

# Fit the scaler on the training split only, so that no test-set statistics leak
# into preprocessing. The test split reuses the training mean/std.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_test_scaled = scaler.transform(x_test)

lasso = Lasso(alpha=LASSO_ALPHA)
lasso.fit(x_train_scaled, y_train)
y_pred = lasso.predict(x_test_scaled)

mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Mean Squared Error: {mse}')
print(f'R^2 Score: {r2}')

# The scaled arrays carry no column names, so label the coefficients from the
# frame the model was fit on (scaling is column-wise and keeps column order).
print('Lasso Coefficients:')
for feature, coef in zip(x_train.columns, lasso.coef_):
    print(f'{feature}: {coef}')
    

Mean Squared Error: 2884.624288735213
R^2 Score: 0.455541399027904
Lasso Coefficients:
age: 1.7304505591468704
sex: -11.316359105598492
bmi: 25.82462698851206
bp: 16.64425156281702
s1: -29.35841191258418
s2: 13.275844111235715
s3: 0.5479478963570505
s4: 10.236168047144627
s5: 29.63282610520826
s6: 2.3934752147470766


In [8]:
# Tabulate the Lasso weights, then show only the features whose weight the
# L1 penalty did not drive to exactly zero.
lasso_coef = pd.DataFrame({"Feature": x.columns, "Weight": lasso.coef_})

lasso_coef.query("Weight != 0")


Unnamed: 0,Feature,Weight
0,age,1.730451
1,sex,-11.316359
2,bmi,25.824627
3,bp,16.644252
4,s1,-29.358412
5,s2,13.275844
6,s3,0.547948
7,s4,10.236168
8,s5,29.632826
9,s6,2.393475


In [9]:
from sklearn.linear_model import Ridge

# Fit an L2-regularised model on the same scaled training data.
# fit() returns the estimator itself, so `ridge` is the fitted model.
ridge = Ridge(alpha=1.0).fit(x_train_scaled, y_train)

# Put the two models' weights side by side, one row per feature.
comparison = pd.DataFrame(
    {
        "Feature": x.columns,
        "Ridge Weight": ridge.coef_,
        "Lasso Weight": lasso.coef_,
    }
)

comparison


Unnamed: 0,Feature,Ridge Weight,Lasso Weight
0,age,1.807342,1.730451
1,sex,-11.44819,-11.316359
2,bmi,25.732699,25.824627
3,bp,16.7343,16.644252
4,s1,-34.671954,-29.358412
5,s2,17.053075,13.275844
6,s3,3.369914,0.547948
7,s4,11.76426,10.236168
8,s5,31.378384,29.632826
9,s6,2.458139,2.393475
