# GDP 预测

In [1]:
import pandas as pd
from pandas_datareader import wb

import torch
import torch.nn
import torch.optim

读取数据

In [2]:
countries = ['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL', 'JP', 'SA', 'GB', 'US',]
dat = wb.download(indicator='NY.GDP.PCAP.KD',
        country=countries, start=1970, end=2016)
df = dat.unstack().T
df.index = df.index.droplevel(0).astype(int)
df

country,Brazil,Canada,China,France,Germany,India,Israel,Japan,Saudi Arabia,United Kingdom,United States
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1970,4706.126393,24629.215564,228.317703,20090.770923,19624.749759,365.057383,14476.725344,18435.455076,22133.904924,17934.191423,23309.620946
1971,5108.945626,25262.44135,237.813838,20985.170602,20202.433743,362.767725,14750.563777,19054.841724,25517.184135,18481.731208,23775.276923
1972,5586.683824,26216.591277,240.881889,21739.884077,20970.626433,352.550056,16127.11638,20370.673766,29931.470715,19211.55651,24760.145377
1973,6216.130816,27571.292534,253.714373,22903.302398,21903.403507,355.78821,16352.902703,21825.54372,35393.583566,20422.489067,25908.912802
1974,6617.88359,28080.94072,254.267485,23690.600278,22089.748966,351.708069,16901.495224,21150.496237,39125.445624,19906.842338,25540.501003
1975,6798.096714,28057.047887,271.599476,23298.340492,21980.08789,375.083373,17055.638892,21458.04982,33860.391587,19613.86382,25239.919906
1976,7287.519095,29128.014087,263.230622,24174.746998,23167.059261,372.642613,16732.882702,22146.595979,37905.118963,20189.605653,26347.809282
1977,7443.898645,29783.267924,279.324547,24907.284189,23996.772912,390.63668,16531.778366,22897.185145,38557.214887,20689.657017,27286.251514
1978,7504.29602,30651.632553,307.766195,25811.802598,24740.236527,403.633544,17085.997506,23887.179963,34657.430841,21557.849741,28500.240457
1979,7824.864356,31502.044296,326.768369,26641.929481,25755.657807,373.832253,17686.527449,24985.791175,36676.579817,22344.368377,29082.593778


搭建神经网络

In [3]:
class Net(torch.nn.Module):
    
    def __init__(self, input_size, hidden_size):
        super(Net, self).__init__()
        self.rnn = torch.nn.LSTM(input_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, 1)
        
    def forward(self, x):
        x = x[:, :, None]
        x, _ = self.rnn(x)
        x = self.fc(x)
        x = x[:, :, 0]
        return x

net = Net(input_size=1, hidden_size=5)
net

Net(
  (rnn): LSTM(1, 5)
  (fc): Linear(in_features=5, out_features=1, bias=True)
)

训练神经网络

In [4]:
# 数据归一化
df_scaled = df / df.loc[2000]

# 确定训练集和测试集
years = df.index
train_seq_len = sum((years >= 1971) & (years <= 2000))
test_seq_len = sum(years > 2000)
print ('训练集长度 = {}, 测试集长度 = {}'.format(
        train_seq_len, test_seq_len))

# 确定训练使用的特征和标签
inputs = torch.tensor(df_scaled.iloc[:-1].values, dtype=torch.float32)
labels = torch.tensor(df_scaled.iloc[1:].values, dtype=torch.float32)

# 训练网络
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters())
for step in range(10001):
    if step:
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    
    preds = net(inputs)
    train_preds = preds[:train_seq_len]
    train_labels = labels[:train_seq_len]
    train_loss = criterion(train_preds, train_labels)
    
    test_preds = preds[-test_seq_len]
    test_labels = labels[-test_seq_len]
    test_loss = criterion(test_preds, test_labels)
    
    if step % 500 == 0:
        print ('第{}次迭代: loss (训练集) = {}, loss (测试集) = {}'.format(
                step, train_loss, test_loss))

训练集长度 = 30, 测试集长度 = 16
第0次迭代: loss (训练集) = 1.4283555746078491, loss (测试集) = 1.9456963539123535
第500次迭代: loss (训练集) = 0.051086507737636566, loss (测试集) = 0.010078947991132736
第1000次迭代: loss (训练集) = 0.01944635808467865, loss (测试集) = 0.002939914120361209
第1500次迭代: loss (训练集) = 0.007834377698600292, loss (测试集) = 0.0010365095222368836
第2000次迭代: loss (训练集) = 0.004264211747795343, loss (测试集) = 0.0005575661198236048
第2500次迭代: loss (训练集) = 0.002953553106635809, loss (测试集) = 0.0005140299326740205
第3000次迭代: loss (训练集) = 0.0023847392294555902, loss (测试集) = 0.0005197693244554102
第3500次迭代: loss (训练集) = 0.002063404768705368, loss (测试集) = 0.0005044733406975865
第4000次迭代: loss (训练集) = 0.001874488778412342, loss (测试集) = 0.0004906203248538077
第4500次迭代: loss (训练集) = 0.0017608855850994587, loss (测试集) = 0.0004924260429106653
第5000次迭代: loss (训练集) = 0.001682320493273437, loss (测试集) = 0.0005076289526186883
第5500次迭代: loss (训练集) = 0.0016172596951946616, loss (测试集) = 0.0005304039223119617
第6000次迭代: loss (训练集) = 0.0

预测

In [5]:
preds = net(inputs)
df_pred_scaled = pd.DataFrame(preds.detach().numpy(),
        index=years[1:], columns=df.columns)
df_pred = df_pred_scaled * df.loc[2000]
df_pred.loc[2001:]

country,Brazil,Canada,China,France,Germany,India,Israel,Japan,Saudi Arabia,United Kingdom,United States
year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2001,8728.474609,44115.8125,1803.890381,38758.660156,38115.324219,772.356262,27742.117188,41971.28125,17978.492188,35744.425781,45331.175781
2002,8791.182617,44121.328125,1922.057007,39006.152344,38580.707031,783.310364,27209.820312,42332.28125,17640.318359,36342.664062,45034.355469
2003,8888.078125,44490.3125,2051.746582,38855.230469,38287.113281,793.61438,26333.099609,42042.03125,16624.966797,36820.015625,45028.613281
2004,8859.866211,44770.628906,2198.029053,38760.089844,37806.5,837.694824,26041.751953,42403.21875,17823.111328,37667.347656,45898.242188
2005,9210.673828,45581.945312,2339.192139,39543.230469,38244.601562,887.680359,26928.919922,43388.410156,19140.501953,38211.289062,47274.554688
2006,9422.847656,46537.699219,2496.956055,39985.117188,38723.804688,939.49292,27781.509766,44090.53125,19460.792969,38858.921875,48266.789062
2007,9592.635742,47181.664062,2679.909912,40467.410156,40128.886719,985.826965,28578.460938,44437.285156,19059.839844,39318.265625,48683.855469
2008,9980.925781,47384.859375,2889.713379,41026.304688,41431.019531,1033.875,29471.3125,44868.574219,18686.330078,39690.921875,48707.433594
2009,10308.732422,47094.582031,2964.640137,40746.019531,41567.589844,1020.524536,29582.595703,44243.945312,19352.714844,39017.917969,47921.152344
2010,10042.78418,45130.597656,3082.189941,39192.152344,38951.582031,1063.351318,28979.738281,41773.0625,18638.841797,37010.519531,46168.046875
