In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

In [2]:
class LinearRegression:
    def __init__(self):
        #系数coef
        self.coef = None
        #截距interception
        self.interception_ = None
        #向量，私有变量
        self._theta = None
        
    def fit_normal(self, X_train, y_train):
        #矩阵的shape属性可以获取矩阵的形状（例如二维数组的行列），获取的结果是一个元组
        assert X_train.shape[0] == y_train.shape[0]
        #hstack():按照指定的轴对数组序列进行联结
        #多件一个行数等于X_train的行数，元素全为1的列向量，之后和X_train拼在一起
        X_b = np.hstack([np.ones((len(X_train), 1)), X_train])
        #np.linalg.inv():求逆
        #套正规方程解的公式计算
        self._theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y_train);
        #截距就是第一个theta
        self.interception_ = self._theta[0]
        #然后系数就是theta向量
        self.coef_ = self._theta[1:]
        
        return self
    
    def predict(self, X_predict):
        assert self.interception_ is not None and self.coef_ is not None
        assert X_predict.shape[1] == len(self.coef_)
        X_b = np.hstack([np.ones((len(X_train), 1)), X_train])
        #
        return X_b.dot(self._theta)
    
    def __repr__(self):
        return "LinearRegression()"

In [3]:
boston = datasets.load_boston()

X = boston.data
y = boston.target

X = X[y < 50.0]
y = y[y < 50.0]

In [4]:
X.shape

(490, 13)

In [5]:
#随机划分训练集和测试集
from sklearn.model_selection import train_test_split
X_train, X_text, y_train, y_test = train_test_split(X,y)

In [6]:
reg = LinearRegression()
reg.fit_normal(X_train, y_train)

LinearRegression()

In [7]:
reg.coef_

array([-9.70686499e-02,  3.49375709e-02, -6.89845603e-02,  1.94842291e-01,
       -1.51961777e+01,  3.33774758e+00, -1.70026248e-02, -1.36479945e+00,
        2.58651943e-01, -1.26625140e-02, -8.23896645e-01,  8.72810866e-03,
       -3.50910681e-01])

In [8]:
reg.interception_

35.79620574520362

**很感动，终于有数据出来而不是报错了**

In [9]:
reg.predict(X_train)

array([16.0663101 , 21.12732624, 17.08301079, 23.17817981, 20.21418577,
       21.98188016, 28.16551667, 19.99395113, 28.74070026, 24.76445871,
       18.30292127, 22.72052518, 21.15594707, 23.14987531, 14.53505178,
       19.68334059, 22.23588108, 23.39381066, 20.92839986, 22.44611004,
       20.73420469, 27.26216145, 17.78279821, 20.49238406, 12.78902095,
       17.44126591, 18.41068371, 23.94397758, 31.35298938, 24.44228228,
       22.28389917, 12.89334452, 18.69479829, 15.85600857, 16.55933285,
       15.06417158, 21.49284462, 26.71923103, 30.77638866,  8.18173596,
       10.7809513 , 17.84894327, 16.87162456, 20.90128172, 20.88926711,
       34.36594908, 27.08472543, 24.02199469, 20.4641716 , 22.88359358,
       14.05721137, 25.89977421, 33.43108635, 26.38294452, 27.8201348 ,
       13.63484739, 24.75917883, 22.45135256, 24.66961684, 23.86205159,
       25.52713977, 18.87513544, 17.18527534, 11.33840595, 21.7857998 ,
       29.30362774, 14.53849681, 21.11084038, 13.74370615, 25.21