### Stochastic Gradient Descent

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.linear_model import SGDRegressor, LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [2]:
# The CSV's first column has no header (rendered as "Unnamed: 0" in the preview)
# and holds a running row id, so use it as the index rather than as a feature.
credit = pd.read_csv(
    'data/Credit.csv',
    index_col=0,
)

In [3]:
# Preview the first five rows to confirm the columns parsed as expected.
credit.head(n=5)

Unnamed: 0,Income,Limit,Rating,Cards,Age,Education,Gender,Student,Married,Ethnicity,Balance
1,14.891,3606,283,2,34,11,Male,No,Yes,Caucasian,333
2,106.025,6645,483,3,82,15,Female,Yes,Yes,Asian,903
3,104.593,7075,514,4,71,11,Male,No,No,Asian,580
4,148.924,9504,681,3,36,11,Female,No,No,Asian,964
5,55.882,4897,357,2,68,16,Male,No,Yes,Caucasian,331


In [4]:
# Two numeric predictors (Income, Limit) and the regression target (Balance).
X = credit.loc[:, ['Income', 'Limit']]
y = credit.loc[:, 'Balance']

In [5]:
# Hold out 25% of the rows (sklearn's default fraction, stated explicitly) for
# evaluation; the fixed seed keeps the split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

#### Fitting a basic Linear Regression model

In [6]:
# Plain linear-regression baseline that the SGD models below are compared to.
lr = LinearRegression().fit(X_train, y_train)

# sklearn metrics take (y_true, y_pred). MSE happens to be symmetric, but the
# documented order stays correct if the metric is swapped (e.g. r2_score) or
# sample weights are added.
train_mse = mean_squared_error(y_train, lr.predict(X_train))
test_mse = mean_squared_error(y_test, lr.predict(X_test))

print(train_mse)
print(test_mse)

25672.745395864164
31919.424629039902


#### Fitting a basic SGD model

In [7]:

# SGD with all-default hyperparameters, trained on the raw (unscaled) features.
# Its MSE comes out many orders of magnitude above the LinearRegression
# baseline, i.e. it failed to fit. Note that `Limit` is in the thousands while
# `Income` is in the tens to hundreds.
sgd_defaults = SGDRegressor(random_state=42).fit(X_train, y_train)

# Metric arguments in sklearn's documented (y_true, y_pred) order.
train_mse = mean_squared_error(y_train, sgd_defaults.predict(X_train))
test_mse = mean_squared_error(y_test, sgd_defaults.predict(X_test))

print(sgd_defaults)
print(train_mse)
print(test_mse)

SGDRegressor(random_state=42)
3.203666114841017e+30
3.030853444775302e+30


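#### Scaling the Data

With default settings, SGD failed on the raw features: its MSE is around 3e30, versus about 2.6e4 for `LinearRegression`. The two features sit on very different scales (`Limit` is in the thousands, `Income` in the tens to hundreds), and SGD's gradient steps are sensitive to feature scale. We therefore standardize each feature to zero mean and unit variance, fitting the scaler on the training split only.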

In [10]:
from sklearn.preprocessing import StandardScaler

In [11]:
# Learn the per-feature mean/std from the training split only, then apply that
# same transform to both splits so no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_tr_scaled, X_ts_scaled = scaler.transform(X_train), scaler.transform(X_test)

In [12]:

# The same default SGD model, now trained on standardized features. Its errors
# should land close to the LinearRegression baseline.
sgd_scaled = SGDRegressor(random_state=42).fit(X_tr_scaled, y_train)

# Metric arguments in sklearn's documented (y_true, y_pred) order.
train_mse = mean_squared_error(y_train, sgd_scaled.predict(X_tr_scaled))
test_mse = mean_squared_error(y_test, sgd_scaled.predict(X_ts_scaled))

print(sgd_scaled)
print(train_mse)
print(test_mse)

SGDRegressor(random_state=42)
25685.399481208602
32139.448998350428
