# 梯度下降法
## 理论
### 确定损失函数（$J(\Theta)$）和步长（$\alpha$）
### 迭代公式：
$\Theta^{1} = \Theta^{0} - \alpha\nabla J(\Theta^{0})$，一般地 $\Theta^{t+1} = \Theta^{t} - \alpha\nabla J(\Theta^{t})$
（单变量情形：$\nabla J(\Theta^{0})=J'(\theta^{0})$）
#### 与牛顿法区别：
牛顿法每一步都需要计算二阶导数（Hessian 矩阵）并求逆，计算量更大，在复杂（高维）函数中较为困难；梯度下降法只需要一阶导数，更适合复杂的函数，但收敛速度通常比牛顿法慢。

## 用梯度下降法解决线性回归问题
将均方误差（MSE）作为损失函数：
$MSE=\frac{1}{n}\sum_{i=1}^{n}(\hat{y}_{i}-y_{i})^2,\quad \hat{y}_{i}=\hat{\beta}_{0}+\sum_{j=1}^{k}\hat{\beta}_{j}x_{ij}$
（注意：MSE 一般不等于 $Var(\hat{Y}-Y)$，因为方差会先减去误差的均值；只有误差均值为 0 时二者才相等。）

### 根据 $Y=\beta_{0}+\sum_{j=1}^{4}\beta_{j}X_{j}+\varepsilon$ 生成线性回归数据（$X_{j}\sim U(0,100)$，$\varepsilon\sim N(0,1)$）

In [2]:
import numpy as np
import pandas as pd
import matplotlib as mpl
np.random.seed(123)
# True coefficients used to simulate the data: one intercept followed by four
# slopes, each drawn uniformly from [-5, 5).
beta = np.random.uniform(low=-5, high=5, size=5)
beta0, beta1 = beta[0], beta[1:]
print(beta0, beta1)

1.9646918559786162 [-2.13860665 -2.73148546  0.51314769  2.1946897 ]


In [10]:
def Beta(X,b,k):
    """Return the linear predictor b[0] + b[1]*X[0] + ... + b[k]*X[k-1].

    X -- the k regressor values of one observation, as an array holding
         exactly k items (e.g. shape (k,) or (k, 1)).
    b -- coefficient array with exactly k + 1 items, intercept first.
    k -- number of regressors.
    """
    # Prepend the constant 1 so that the intercept b[0] enters the dot product.
    design_row = np.concatenate(([1], X.reshape(k)))
    return np.dot(b.reshape(k + 1), design_row)

In [8]:
np.random.seed(123)
# 80 regressor values drawn uniformly from [0, 100), laid out as
# 4 variables (rows) x 20 observations (columns).
# NOTE(review): reseeding with 123 restarts the same global random stream that
# produced `beta`, so the first draws here are rescaled copies of those.
x = np.random.uniform(low=0, high=100, size=80).reshape((4, 20))
print(x)

[[69.64691856 28.6139335  22.68514536 55.13147691 71.94689698 42.31064601
  98.07641984 68.48297386 48.09319015 39.21175182 34.31780162 72.90497074
  43.85722447  5.96778966 39.80442553 73.79954057 18.24917305 17.54517561
  53.15513738 53.18275871]
 [63.44009586 84.94317941 72.44553249 61.10235107 72.24433826 32.29589139
  36.17886556 22.82632309 29.37140464 63.09761239  9.21049399 43.37011727
  43.08627633 49.36850977 42.58302903 31.2261223  42.6351307  89.33891631
  94.41600182 50.18366759]
 [62.39529518 11.56183951 31.72854818 41.4826212  86.63091579 25.04553654
  48.30342643 98.55597856 51.94851193 61.28945258 12.0628666  82.63408005
  60.30601284 54.50680065 34.27638338 30.4120789  41.7022211  68.13007658
  87.54568418 51.04223375]
 [66.9313783  58.59365526 62.49035021 67.4689051  84.23424376  8.31949883
  76.36828414 24.36663745 19.42229606 57.24569575  9.57125166 88.53268263
  62.72489721 72.34163582  1.61292067 59.44318794 55.67851924 15.89596441
  15.30705151 69.55295288]]


In [11]:
# Noise-free responses: evaluate the true linear predictor on each of the
# 20 observations (one column of x per observation).
Y = []
for i in range(20):
    X = x[:, i].reshape(4, 1).copy()
    Y.append(Beta(X, beta, 4))
print(Y)
# NOTE(review): reseeding with 123 restarts the same global random stream used
# for `beta` and `x`, so epsilon may not be independent of the regressors;
# consider a different seed.
np.random.seed(123)
epsilon = np.random.normal(loc=0, scale=1, size=20)  # standard-normal noise
Y = Y + epsilon  # list + ndarray -> ndarray of noisy responses
print(Y)
alpha = 0.01  # learning-rate constant; not passed to any call in this cell

[-141.35676299850888, -156.72249366630322, -91.00547076081516, -113.48001286961762, -119.91329895378695, -145.6261173489736, -114.21275908336358, -102.79224019587546, -111.83211727419018, -97.15695266024953, -69.38995781743454, -35.709954476410786, -40.9106206526629, 41.09005212236997, -178.34753477389836, -95.09195532889615, -9.923879724955938, -209.73798789421008, -291.09108109932566, -70.00891466104719]
[-142.4423936  -155.72514822  -90.72249226 -114.98630758 -120.49189921
 -143.97468081 -116.63943833 -103.22115282 -110.56618102  -98.02369306
  -70.06884397  -35.80466345  -39.41923103   40.45115013 -178.79151673
  -95.5263066    -7.71794964 -207.55120181 -290.0870272   -69.62272826]


### 求损失函数对各参数的偏导数
$$\frac{\partial J(\hat{\beta})}{\partial\hat{\beta}_{j}}=\begin{cases}
\frac{2}{n}\sum_{i=1}^{n}\left(\hat{\beta}_{0}+\sum_{l=1}^{k}\hat{\beta}_{l}x_{il}-y_{i}\right), & j=0\ (\text{对 }\hat{\beta}_{0}\text{ 求偏导})\\
\frac{2}{n}\sum_{i=1}^{n}\left(\hat{\beta}_{0}+\sum_{l=1}^{k}\hat{\beta}_{l}x_{il}-y_{i}\right)x_{ij}, & j=1,\dots,k\ (\text{对 }\hat{\beta}_{j}\text{ 求偏导})
\end{cases}$$

In [6]:
def partical(beta_hat, n, k, data_x=None, data_y=None):
    """Return the gradient of the MSE loss with respect to (beta_0, ..., beta_k).

    With residuals r_i = beta_0 + sum_j beta_j * x_ij - y_i for i = 1..n:
        dJ/dbeta_0 = (2/n) * sum_i r_i
        dJ/dbeta_j = (2/n) * sum_i r_i * x_ij

    Args:
        beta_hat: current coefficient estimate, k + 1 values, intercept first.
        n: number of observations to use (the first n columns / entries).
        k: number of regressors; ``data_x`` must have exactly k rows.
        data_x: regressor matrix of shape (k, N), N >= n, one observation per
            column. Defaults to the notebook-level ``x``.
        data_y: responses, at least n values. Defaults to the notebook-level ``Y``.

    Returns:
        numpy array of length k + 1:
        (2/n) * [sum r_i, sum r_i * x_i1, ..., sum r_i * x_ik].

    Raises:
        ValueError: if ``data_x`` does not have k rows, if ``n`` is not between
            1 and the number of available observations, or if ``beta_hat``
            does not hold k + 1 values.
    """
    # Fall back to the notebook globals only when no data is passed explicitly.
    xs = np.asarray(x if data_x is None else data_x, dtype=float)
    ys = np.asarray(Y if data_y is None else data_y, dtype=float).ravel()
    if xs.ndim != 2 or xs.shape[0] != k:
        raise ValueError(f"data_x must have shape (k, N) with k={k}, got {xs.shape}")
    available = min(xs.shape[1], ys.size)
    if not 0 < n <= available:
        raise ValueError(f"n must be between 1 and {available}, got {n}")
    coef = np.asarray(beta_hat, dtype=float).reshape(k + 1)
    # Use the same first n observations for both the residuals and the
    # per-regressor sums (the original summed over every column of x).
    xs = xs[:, :n]
    residual = coef[0] + coef[1:] @ xs - ys[:n]
    return 2 / n * np.concatenate(([residual.sum()], xs @ residual))

In [7]:
def gradient(origin, accuracy = 1e-5, alpha = 0.01, *, n=20, max_iter=None,
             verbose=True, grad_fn=None):
    """Minimise the loss by batch gradient descent, starting from ``origin``.

    Each step applies theta <- theta - alpha * grad(theta). One step is taken
    unconditionally. After that, iteration stops as soon as the largest step
    component, alpha * max|grad(theta)|, drops below ``accuracy``. The same
    criterion is applied before the first loop iteration and after every
    iteration. Iteration also stops after ``max_iter`` loop iterations.

    Args:
        origin: initial coefficient vector (intercept first); its length fixes
            the number of regressors k = len(origin) - 1.
        accuracy: tolerance on the largest step component.
        alpha: learning rate (step size).
        n: number of observations handed to the gradient function (20 is the
            sample size used in this notebook).
        max_iter: optional cap on loop iterations; None means no cap.
        verbose: if True, print theta after every loop iteration.
        grad_fn: gradient callable ``grad_fn(theta, n, k)``; defaults to
            ``partical``, which reads the notebook-level data.

    Returns:
        numpy float array with the final coefficient estimate.

    Raises:
        FloatingPointError: if the gradient becomes inf or NaN, i.e. the
            iteration diverged (usually because alpha is too large).
    """
    if grad_fn is None:
        grad_fn = partical
    theta = np.asarray(origin, dtype=float)
    k = theta.size - 1
    theta = theta - alpha * np.asarray(grad_fn(theta, n, k), dtype=float)
    grad = np.asarray(grad_fn(theta, n, k), dtype=float)
    iterations = 0
    while True:
        delta = alpha * np.max(np.abs(grad))
        # Without this check a diverging run ends with NaN (NaN >= accuracy is
        # False), which would silently return garbage.
        if not np.isfinite(delta):
            raise FloatingPointError(
                "gradient descent diverged (non-finite gradient); try a smaller alpha")
        if delta < accuracy:
            break
        if max_iter is not None and iterations >= max_iter:
            break
        theta = theta - alpha * grad
        iterations += 1
        if verbose:
            print(theta)
        # Reuse this gradient for both the stopping test and the next step.
        grad = np.asarray(grad_fn(theta, n, k), dtype=float)
    return theta

In [12]:
# Run gradient descent from a starting point close to the true coefficients,
# with step-size tolerance 1e-5 and learning rate 1e-5.
gradient(np.array([2, -2, -2, 1, 2]), accuracy=1e-5, alpha=1e-5)

[ 1.99785157 -2.10425358 -2.12555408  0.87436561  1.89554085]
[ 1.99710994 -2.1392733  -2.17002367  0.83059074  1.86134516]
[ 1.99652677 -2.1661553  -2.20575792  0.79589324  1.83572137]
[ 1.99606741 -2.18668803 -2.23465655  0.76829252  1.81678052]
[ 1.99570481 -2.20226834 -2.25820358  0.74624148  1.80304731]
[ 1.99541783 -2.21398722 -2.27755834  0.7285316   1.79336951]
[ 1.99518996 -2.22269675 -2.29362653  0.71421883  1.78684727]
[ 1.99500831 -2.2290624  -2.30711576  0.70256568  1.78277785]
[ 1.99486281 -2.23360385 -2.31857887  0.69299605  1.78061248]
[ 1.99474558 -2.23672695 -2.32844785  0.68505986  1.77992267]
[ 1.99465048 -2.23874862 -2.33706027  0.67840548  1.78037386]
[ 1.99457271 -2.23991631 -2.34467996  0.67275818  1.78170484]
[ 1.99450851 -2.24042321 -2.35151312  0.6679033   1.78371174]
[ 1.99445496 -2.24042014 -2.35772097  0.6636731   1.78623544]
[ 1.99440976 -2.24002482 -2.36342958  0.65993649  1.78915178]
[ 1.99437113 -2.23932912 -2.36873755  0.656591    1.79236391]
[ 1.9943

[ 1.99317434 -2.13910794 -2.66361779  0.51708801  2.12228788]
[ 1.99317297 -2.13906559 -2.66410941  0.51694008  2.12289685]
[ 1.99317163 -2.13902522 -2.66459553  0.51679497  2.12349923]
[ 1.99317032 -2.13898681 -2.66507621  0.51665265  2.12409507]
[ 1.99316903 -2.13895032 -2.66555153  0.51651307  2.12468447]
[ 1.99316777 -2.13891572 -2.66602154  0.5163762   2.12526749]
[ 1.99316654 -2.13888298 -2.66648632  0.516242    2.1258442 ]
[ 1.99316533 -2.13885206 -2.66694592  0.51611043  2.12641468]
[ 1.99316415 -2.13882294 -2.66740041  0.51598146  2.126979  ]
[ 1.993163   -2.13879558 -2.66784986  0.51585506  2.12753723]
[ 1.99316186 -2.13876995 -2.66829431  0.51573118  2.12808943]
[ 1.99316076 -2.13874601 -2.66873383  0.5156098   2.12863567]
[ 1.99315968 -2.13872376 -2.66916849  0.51549088  2.12917603]
[ 1.99315862 -2.13870314 -2.66959834  0.51537438  2.12971057]
[ 1.99315759 -2.13868413 -2.67002344  0.51526028  2.13023935]
[ 1.99315658 -2.13866671 -2.67044385  0.51514854  2.13076244]
[ 1.9931

[ 1.99321641 -2.14581587 -2.70496194  0.51386626  2.17356057]
[ 1.9932175  -2.14586265 -2.70502949  0.51389792  2.17363954]
[ 1.99321859 -2.14590931 -2.70509648  0.51392962  2.17371779]
[ 1.99321968 -2.14595584 -2.70516293  0.51396136  2.17379533]
[ 1.99322078 -2.14600224 -2.70522883  0.51399314  2.17387217]
[ 1.99322188 -2.14604852 -2.7052942   0.51402496  2.17394831]
[ 1.99322298 -2.14609468 -2.70535904  0.51405682  2.17402376]
[ 1.9932241  -2.14614071 -2.70542336  0.5140887   2.17409853]
[ 1.99322521 -2.14618661 -2.70548715  0.51412062  2.17417262]
[ 1.99322633 -2.14623237 -2.70555043  0.51415256  2.17424605]
[ 1.99322745 -2.14627802 -2.7056132   0.51418453  2.17431881]
[ 1.99322858 -2.14632353 -2.70567546  0.51421652  2.17439092]
[ 1.99322971 -2.1463689  -2.70573723  0.51424853  2.17446239]
[ 1.99323085 -2.14641415 -2.7057985   0.51428056  2.17453321]
[ 1.99323199 -2.14645927 -2.70585928  0.5143126   2.17460339]
[ 1.99323313 -2.14650425 -2.70591957  0.51434466  2.17467295]
[ 1.9932

[ 1.99349217 -2.15247782 -2.7120403   0.51938913  2.18118795]
[ 1.99349374 -2.15249884 -2.71205732  0.51940921  2.18120386]
[ 1.99349532 -2.15251976 -2.71207423  0.51942922  2.18121967]
[ 1.99349689 -2.15254058 -2.71209104  0.51944915  2.18123536]
[ 1.99349847 -2.1525613  -2.71210775  0.51946899  2.18125094]
[ 1.99350004 -2.15258192 -2.71212437  0.51948876  2.18126642]
[ 1.99350162 -2.15260245 -2.71214089  0.51950846  2.18128179]
[ 1.9935032  -2.15262289 -2.71215731  0.51952807  2.18129706]
[ 1.99350478 -2.15264322 -2.71217364  0.5195476   2.18131222]
[ 1.99350636 -2.15266347 -2.71218987  0.51956706  2.18132728]
[ 1.99350794 -2.15268361 -2.71220601  0.51958643  2.18134224]
[ 1.99350952 -2.15270367 -2.71222205  0.51960573  2.18135709]
[ 1.99351111 -2.15272362 -2.712238    0.51962496  2.18137184]
[ 1.99351269 -2.15274349 -2.71225386  0.5196441   2.18138649]
[ 1.99351428 -2.15276326 -2.71226963  0.51966317  2.18140104]
[ 1.99351586 -2.15278294 -2.7122853   0.51968216  2.1814155 ]
[ 1.9935

array([ 1.9937552 , -2.15483394, -2.71384055,  0.52172727,  2.18277257])

In [None]:
# True coefficients (intercept, then the four slopes) used to simulate the
# data, shown for comparison with the gradient-descent estimate.
print(beta0, beta1)