Modify the regression scratch code in our lecture such that:

- Implement early stopping in which if the absolute difference between old loss and new loss does not exceed certain threshold, we abort the learning.

- Implement options for stochastic gradient descent in which we use only one sample for training.  Make sure that sample does not repeat unless all samples are read at least once already.

- Put everything into class.

In [3]:
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np

boston = load_boston()
X = boston.data
y = boston.target
m = X.shape[0]  #number of samples
n = X.shape[1]  #number of features

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3)

# actually you can do like this too
# X = np.insert(X, 0, 1, axis=1)
intercept = np.ones((X_train.shape[0], 1))
X_train = np.concatenate((intercept, X_train), axis=1)
intercept = np.ones((X_test.shape[0], 1))
X_test = np.concatenate((intercept, X_test), axis=1)

class LinearRegression:
    # if batch, set alpha to smaller values
    def __init__(self, alpha=0.001, max_iter=10000, 
            loss_old=10000, tol=1e-5, method="batch"):
        self.alpha = alpha
        self.max_iter = max_iter
        self.loss_old = loss_old
        self.tol = tol
        self.method = method
        
    def fit(self, X, y):
        self.theta = np.zeros(X.shape[1])
        iter_stop = 0
        list_of_used_ix = [] #<===without replacement
        
        for i in range(self.max_iter):
            
            if self.method != "batch":
                i = np.random.randint(X.shape[0])
                while i in list_of_used_ix:
                    i = np.random.randint(X.shape[0])
                X_train = X[i, :].reshape(1, -1)
                y_train = y[i]
                list_of_used_ix.append(i)
                print(list_of_used_ix)
                if len(list_of_used_ix) == X.shape[0]:
                    list_of_used_ix = []
            else:
                X_train = X
                y_train = y
            
            yhat = self.h_theta(X_train)
            error = yhat - y_train
            
            # early stopping
            loss_new = self.mse(yhat, y_train)
            if self.delta_loss(loss_new, self.loss_old, self.tol):  #np.allclose
                iter_stop = i
                break
            self.loss_old = loss_new

            grad = self.gradient(X_train, error)
            self.theta = self.theta - self.alpha * grad

    # can name it predict for easy understanding
    def h_theta(self, X):
        return X @ self.theta

    def mse(self, yhat, y):
        return ((yhat - y)**2 / yhat.shape[0]).sum()

    def delta_loss(self, loss_new, loss_old, tol):
        return np.abs(loss_new - loss_old) < tol

    def gradient(self, X, error):
        return X.T @ error

model = LinearRegression(method="sto") #<==try put method="sto"
model.fit(X_train, y_train)
yhat = model.h_theta(X_test)
mse = model.mse(yhat, y_test)

# print the mse
print("MSE: ", mse)

[201]
[201, 262]
[201, 262, 77]
[201, 262, 77, 177]
[201, 262, 77, 177, 158]
[201, 262, 77, 177, 158, 190]
[201, 262, 77, 177, 158, 190, 224]
[201, 262, 77, 177, 158, 190, 224, 33]
[201, 262, 77, 177, 158, 190, 224, 33, 199]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 223]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 223, 307]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 223, 307, 128]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 223, 307, 128, 79]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 223, 307, 128, 79, 46]
[201, 262, 77, 177, 158, 190, 224, 33, 199, 323, 173, 258, 175, 62, 2

[323, 250, 51, 183, 96, 93, 179, 112, 256, 26, 274, 154, 119, 287, 171, 133, 4, 157, 104, 61, 160, 28, 292, 200, 332, 219, 230, 203, 146, 276, 7, 234, 29, 83, 80, 10, 187, 85, 333, 302, 334, 214, 31, 235, 100, 310, 284, 168, 254, 318, 169, 216, 231, 136, 59, 135, 79, 15, 145, 56, 143, 178, 156, 50, 98, 315, 220, 300, 45, 288, 311, 280, 242, 265, 217, 201, 335, 170, 329, 262, 121, 141, 282, 137, 266, 241, 33, 153, 94, 321, 132, 324, 236, 40, 48, 184, 90, 62, 41, 140, 12, 163, 299, 285, 232, 202, 259, 46, 182, 296, 25, 295, 345, 350, 113, 111, 27, 76, 165, 150, 225, 55, 294, 352, 18, 14, 338, 126, 243, 326, 130, 245, 193, 281, 247, 215, 54, 120, 67, 192, 13, 209, 57, 151, 19, 75, 16, 97, 255, 307, 102, 261, 233, 319, 305, 308, 336, 331, 144, 20, 103, 293, 68, 340, 218, 244, 227, 211, 64, 290, 249, 24, 139, 207, 127, 322, 21, 114, 42, 267, 188, 149, 47, 91, 263, 101, 271, 0, 65, 349, 264, 181, 185, 221, 116, 325, 317, 158, 8, 174, 166, 110, 128, 258, 198, 257, 223, 206, 66, 224, 298, 118,

[126, 238, 317, 326, 318, 256, 61, 152, 81, 115, 86, 129, 349, 73, 118, 52, 276, 96, 301, 82, 67, 210, 284, 98, 40, 229, 91, 233, 55, 69, 139, 97, 168, 331, 230, 24, 266, 74, 262, 138, 322, 26, 232, 65, 289, 46, 132, 332, 193, 79, 211, 21, 264, 0, 275, 68, 239, 241, 293, 324, 223, 37, 288, 226, 234, 159, 254, 194, 50, 130, 344, 348, 56, 110, 306, 42, 240, 11, 158, 243, 104, 268, 292, 170, 90, 323, 149, 72, 99, 258, 286, 15, 222, 219, 248, 352, 175, 285, 135, 308, 16, 242, 216, 100, 107, 261, 171, 176, 246, 36, 20, 328, 4, 350, 18, 213, 116, 19, 78, 291, 45, 183, 102, 309, 31, 155, 150, 38, 203, 320, 298, 197, 48, 106, 209, 181, 255, 94, 263, 17, 201, 119, 165, 225, 190, 70, 63, 3, 267, 346, 84, 12, 131, 133, 337, 259, 299, 162, 313, 271, 177, 218, 58, 113, 66, 141, 13, 297, 231, 108, 151, 186, 319, 307, 109, 196, 236, 27, 345, 339, 174, 76, 60, 8, 310, 327, 9, 154, 30, 278, 277, 221, 111, 157, 296, 123, 125, 250, 59, 224, 29, 54, 112, 206, 101, 302, 281, 180, 279, 195, 198, 188, 244]
[

[351, 282, 297, 111, 58, 92, 87, 129, 150, 316, 102, 57, 269, 290, 9, 127, 61, 210, 198, 134, 268, 344, 246, 27, 287, 229, 25, 223, 186, 91, 66, 276, 215, 348, 278, 244, 249, 30, 42, 319, 79, 300, 169, 136, 328, 286, 151, 309, 173, 205, 237, 211, 38, 11, 36, 329, 56, 93, 218, 310, 232, 109, 291, 69, 242, 120, 35, 303, 39, 146, 340, 294, 265, 157, 299, 175, 222, 353, 332, 59, 327, 220, 137, 347, 208, 274, 180, 275, 346, 99, 41, 1, 105, 333, 201, 257, 52, 10, 239, 90, 14, 62, 219, 46, 203, 74, 152, 283, 135, 26, 345, 317, 24, 164, 262, 240, 128, 324, 21, 43, 308, 171, 28, 331, 13, 115, 18, 295, 241, 233, 188, 140, 183, 302, 124, 187, 161, 197, 130, 281, 315, 149, 142, 122, 131, 338, 256, 68, 22, 225, 209, 216, 217, 85, 139, 337, 323, 133, 226, 63, 259, 153, 221, 100, 44, 155, 272, 318, 204, 81, 49, 212, 190, 243, 159, 258, 184, 296, 170, 80, 32, 64, 178, 29, 8, 266, 117, 277, 31, 270, 289, 168, 156, 292, 84, 350, 71, 113, 339, 336, 112, 144, 207, 224, 132, 143, 167, 253, 177, 72, 83, 15,

[89, 191, 125, 159, 231, 351, 88, 59, 151, 70, 247, 78, 293, 41, 32, 65, 184, 176, 207, 224, 108, 302, 208, 140, 349, 178, 164, 110, 53, 57, 135, 214, 155, 303, 235, 181, 137, 16, 295, 80, 22, 344, 144, 69, 234, 131, 179, 233, 93, 60, 172, 31, 273, 257, 175, 203, 28, 133, 71, 35, 40, 210, 326, 56, 268, 54, 105, 313, 260, 269, 68, 256, 202, 97, 258, 118, 73, 291, 17, 274, 288, 283, 337, 44, 63, 157, 290, 4, 339, 282, 309, 109, 343, 280, 112, 217, 39, 239, 195, 305, 286, 225, 123, 323, 276, 2, 45, 347, 119, 241, 310, 167, 227, 141, 264, 126, 72, 212, 300, 20, 221, 76, 275, 308, 322, 335, 14, 111, 139, 340, 325, 34, 92, 272, 26, 46, 289, 249, 201, 113, 142, 200, 252, 229, 226, 199, 37, 103, 114, 350, 162, 36, 0, 152, 171, 342, 66, 259, 1, 188, 328, 130, 244, 7, 346, 336, 96, 248, 55, 331, 223, 153, 64, 177, 194, 314, 255, 267, 299, 173, 185, 169, 294, 271, 146, 18, 143, 213, 220, 86, 42, 138, 245, 52, 329, 124, 192, 197, 27, 198, 47, 61, 145, 242, 222, 180, 307, 122, 62, 304, 5, 91, 189, 

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)




[208, 88, 106, 176, 297, 143, 130, 241, 169, 234, 347, 272, 201, 50, 244, 260, 38, 108, 81, 245, 28, 257, 342, 36, 191, 60, 282, 343, 51, 303, 137, 2, 148, 221, 172, 236, 227, 331, 298, 163, 217, 3, 42, 165, 239, 87, 249, 17, 131, 346, 26, 67, 284, 305, 79, 152, 296, 291, 226, 112, 63, 320, 90, 348, 20, 254, 292, 125, 255, 93, 145, 124, 118, 153, 24, 238, 216, 121, 117, 336, 62, 212, 149, 31, 98, 71, 164, 326, 187, 243, 142, 147, 198, 44, 219, 27, 233, 261, 321, 247, 53, 135, 264, 332, 167, 68, 200, 186, 140, 6, 128, 181, 293, 104, 74, 97, 85, 159, 170, 35, 288, 25, 15, 268, 340, 190, 240, 49, 209, 77, 69, 273, 76, 109, 83, 304, 242, 139, 1, 202, 311, 324, 222, 199, 46, 194, 185, 95, 295, 335, 80, 223, 308, 193, 224, 84, 248, 206, 322, 330, 294, 13, 175, 189, 280, 100, 56, 256, 12, 302, 91, 10, 258, 345, 207, 205, 283, 285, 353, 266, 250, 155, 214, 16, 178, 65, 251, 339, 259, 269, 99, 156, 179, 14, 136, 8, 160, 327, 58, 277, 192, 174, 111, 183, 271, 231, 323, 57, 102, 203, 7, 197, 338

[276, 350, 8, 156, 215, 322, 279, 304, 135, 158, 62, 41, 237, 189, 48, 197, 331, 121, 152, 207, 239, 176, 202, 167, 323, 169, 299, 321, 15, 198, 213, 260, 196, 18, 300, 123, 14, 154, 241, 46, 92, 246, 36, 23, 59, 171, 133, 296, 168, 284, 43, 21, 267, 157, 89, 122, 98, 253, 235, 226, 74, 12, 228, 4, 216, 303, 352, 317, 144, 204, 254, 343, 86, 161, 249, 269, 138, 263, 333, 85, 298, 244, 70, 223, 151, 80, 335, 222, 292, 118, 201, 327, 65, 191, 140, 124, 230, 39, 5, 277, 181, 182, 344, 221, 291, 336, 349, 236, 278, 113, 150, 289, 24, 37, 112, 42, 184, 234, 129, 106, 147, 195, 257, 83, 78, 128, 212, 302, 316, 105, 81, 190, 104, 294, 31, 130, 142, 120, 258, 337, 72, 311, 67, 82, 125, 203, 266, 73, 0, 141, 210, 25, 225, 57, 348, 174, 274, 320, 102, 139, 312, 220, 318, 211, 340, 205, 99, 1, 96, 194, 308, 271, 52, 338, 233, 75, 250, 115, 30, 262, 40, 13, 206, 116, 136, 27, 159, 209, 324, 282, 227, 16, 240, 91, 143, 127, 97, 214, 175, 330, 71, 66, 107, 51, 319, 192, 108, 183, 163, 275, 22, 69, 1

[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11, 35]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11, 35, 273]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11, 35, 273, 202]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11, 35, 273, 202, 31]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 76, 188, 135, 44, 302, 140, 228, 330, 11, 35, 273, 202, 31, 152]
[226, 49, 165, 349, 27, 213, 134, 120, 265, 205, 62, 294, 7