In [1]:
import numpy as np
import matplotlib
matplotlib.use('nbagg')

import matplotlib.pyplot as plt

# %matplotlib inline
# import mpld3
# mpld3.enable_notebook()

# plt.ioff()

In [2]:
# make some data and plot it
N = 100
X = np.linspace(0, 6 * np.pi, N)
Y = np.sin(X)

In [3]:
plt.plot(X, Y)
plt.show()

<IPython.core.display.Javascript object>

In [4]:
#xrange is used in python2 -> use range in python3
#np.ones returns a numpy array [] converts it into a python list
# ** and * is not the same
# ** is power or ^ whereas * multiplies each element of numpy array with that number
# ** can be interchanged with np.__pow__(<power>)
# numpy.vstack() function is used to stack the sequence of input arrays vertically to make a single array.

def make_poly(X, deg) :
    n = len(X)
    data = [np.ones(n)]
    for d in range(deg) :
        data.append(X**(d+1))
        
    return np.vstack(data).T

def fit(X, Y) :
    return np.linalg.solve(X.T.dot(X), X.T.dot(Y))

# np.random.choice(a, b) -> a is the input list/array from which element is to be chosen and b is the no of choices we want
def fit_and_display(X, Y , sample, deg) :
    N = len(X)
    train_idx = np.random.choice(N, sample)
    Xtrain = X[train_idx]
    Ytrain = Y[train_idx]
    
    plt.scatter(Xtrain, Ytrain, s = 10)
    plt.show()
    
    #fit polynomial
    Xtrain_poly = make_poly(Xtrain, deg)
    w = fit(Xtrain_poly, Ytrain)
    
    #display the polynomial
    X_poly = make_poly(X, deg)
    Yhat = X_poly.dot(w)
    plt.plot(X, Y)
    plt.plot(X, Yhat)
    plt.scatter(Xtrain, Ytrain, s = 10)
    plt.title("deg = %d" % deg)
    plt.show()

    

In [5]:
#mean-squared error
def get_mse(Y, yhat) :
    d1 = Y - yhat
    d2 = Y - Y.mean()
    r2 = d1.dot(d1)/ d2.dot(d2)
    
    d = Y - yhat
    
    return d.dot(d) / len(d)
#     return r2

In [6]:
def plot_train_vs_test_curves(X, Y, sample = 20, max_deg = 20) :
    N = len(X)
    train_idx = np.random.choice(N, sample)
    Xtrain = X[train_idx]
    Ytrain = Y[train_idx]
    
    test_idx = [idx for idx in range(N) if idx not in train_idx]
    Xtest = X[test_idx]
    Ytest = Y[test_idx]
    
    mse_trains = []
    mse_tests = []
    
    for deg in range(max_deg + 1) :
        Xtrain_poly = make_poly(Xtrain, deg)
        w = fit(Xtrain_poly, Ytrain)
        yhat_train = Xtrain_poly.dot(w)
        mse_train = get_mse(Ytrain, yhat_train)
        
        Xtest_poly = make_poly(Xtest, deg)
        Yhat_test = Xtest_poly.dot(w)
        mse_test = get_mse(Ytest, Yhat_test)
        
        mse_trains.append(mse_train)
        mse_tests.append(mse_test)
        
    plt.figure()
    plt.plot(mse_trains, label="train mse")
    plt.plot(mse_tests, label="test mse")
    plt.legend()
    plt.show()
    
    plt.plot(mse_trains, label="train mse")
    plt.legend()
    plt.show()
    
    print('mse_trains\n', mse_trains)
    print('mse_tests\n', mse_tests)

In [7]:
if __name__ == "__main__":
    # make up some data and plot it
    N = 100
    X = np.linspace(0, 6*np.pi, N)
    Y = np.sin(X)

    plt.plot(X, Y)
    plt.show()

    for deg in (5, 6, 7, 8, 9):
        fit_and_display(X, Y, 10, deg)
    plot_train_vs_test_curves(X, Y, 90)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

mse_trains
 [0.5211248526732347, 0.4835630451817415, 0.4830298857551775, 0.4439574518849414, 0.4438351712587186, 0.43997137753369586, 0.43281421066838166, 0.14409351733955708, 0.14341331287602424, 0.01232433046102077, 0.010599158345989486, 0.06920918638194978, 0.0009437390682696874, 0.000954851010131731, 0.0009624324281845056, 8.026775482815565e-05, 0.007565143395510876, 0.0004960465469109896, 6.459594030578868e-05, 1.715742215803989e-05, 0.0001591058039616367]
mse_tests
 [0.47947805906914603, 0.44868971549466796, 0.44585794050065175, 0.39278585546436734, 0.3908387203943995, 0.41568177553835073, 0.4520047033084811, 0.18973372026860014, 0.1959362824569062, 0.030952067312316024, 0.05239417410747421, 0.4640069379734253, 0.002763845779088974, 0.002710878400836792, 0.002804657527741837, 0.0010575566022845574, 0.11842602796498347, 0.0022605032637840815, 0.00010941668580148945, 0.00010619241169424912, 0.00036063689208558217]
