In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from pygam import LinearGAM, s

In [2]:
# ---------------------------------------------------------
# 1. Read Boston.csv and carve out an 80/20 train/test split
# ---------------------------------------------------------
data = pd.read_csv('Boston.csv')

# Predictors: lstat (column 0 of X) and rm (column 1 of X); response: medv
X = data[['lstat', 'rm']].values
y = data['medv'].values

# random_state is pinned so the same rows land in each split on every run
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [3]:
# ---------------------------------------------------------
# Fit the Generalized Additive Model (GAM)
# ---------------------------------------------------------
# One spline smooth per predictor column of X:
#   column 0 (LSTAT) -> 6 basis functions
#   column 1 (RM)    -> 9 basis functions
gam = LinearGAM(
    s(0, n_splines=6)
    + s(1, n_splines=9)
)
gam.fit(X_train, y_train)

LinearGAM(callbacks=[Deviance(), Diffs()], fit_intercept=True, 
   max_iter=100, scale=None, terms=s(0) + s(1) + intercept, 
   tol=0.0001, verbose=False)

In [14]:
# gam.summary()  # disabled; summary() prints the model table itself, so wrapping it in print() is redundant

In [4]:
# Held-out predictions from the fitted GAM
pred = gam.predict(X_test)

# Test MSE: mean of the squared prediction errors on the test split
mse = ((y_test - pred) ** 2).mean()
print("Test MSE:", mse)

Test MSE: 18.25321563975932


In [5]:
# Show the smoothing penalties (lam) currently stored on `gam`
gam.lam

[[0.6], [0.6]]

In [10]:
# Candidate smoothing penalties: 7 log-spaced values from 1e-3 to 1e3
lams = np.logspace(start=-3, stop=3, num=7)
lams

array([1.e-03, 1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02, 1.e+03])

In [15]:
# Fresh, unfitted GAM with the same two spline terms as the earlier fit
gam = LinearGAM(
    s(0, n_splines=6)
    + s(1, n_splines=9)
)

# No `lam` argument is passed, so gridsearch falls back to its own default
# candidate grid (the progress bar below reports 11 fits). To search the
# custom `lams` grid built in the previous cell instead, pass it explicitly:
#     gam.gridsearch(X_train, y_train, progress=True, lam=lams)
gam.gridsearch(X_train, y_train, progress=True)
print(f"Best lambda: {gam.lam}")

  0% (0 of 11) |                         | Elapsed Time: 0:00:00 ETA:  --:--:--
 72% (8 of 11) |##################       | Elapsed Time: 0:00:00 ETA:   0:00:00
100% (11 of 11) |########################| Elapsed Time: 0:00:00 Time:  0:00:00


Best lambda: [[0.015848931924611134], [0.015848931924611134]]


In [12]:
# Re-score on the test split with whatever model `gam` currently refers to
# (in cell order, the grid-searched one).
# NOTE(review): the recorded execution counts are out of order (this cell ran
# as [12], the grid search as [15]); restart and run all cells to confirm the
# printed value belongs to the grid-searched model.
pred = gam.predict(X_test)
mse = ((y_test - pred) ** 2).mean()
print("Test MSE:", mse)

Test MSE: 18.210322073918814
