In [38]:
import numpy as np
import pandas as pd
import torch
import datetime
import plotly as py
import plotly.graph_objs as go
from sklearn import preprocessing

In [33]:
# Load the temperature dataset from the working directory.
features = pd.read_csv('temps.csv')

# Show the first rows, then report the (rows, columns) shape of the frame.
preview = features.head()
print(preview)
print('数据维度: ', features.shape)

   year  month  day  week  temp_2  temp_1  average  actual  friend
0  2016      1    1   Fri      45      45     45.6      45      29
1  2016      1    2   Sat      44      45     45.7      44      61
2  2016      1    3   Sun      45      44     45.8      41      56
3  2016      1    4   Mon      44      41     45.9      40      53
4  2016      1    5  Tues      41      40     46.0      44      41
数据维度:  (348, 9)


In [34]:
# Pull out the year, month and day columns.
years = features['year']
months = features['month']
days = features['day']

# Build 'Y-M-D' strings first, then parse them into datetime objects.
date_strings = [
    f'{int(year)}-{int(month)}-{int(day)}'
    for year, month, day in zip(years, months, days)
]
print(date_strings)

dates = [datetime.datetime.strptime(text, '%Y-%m-%d') for text in date_strings]
print(dates)

['2016-1-1', '2016-1-2', '2016-1-3', '2016-1-4', '2016-1-5', '2016-1-6', '2016-1-7', '2016-1-8', '2016-1-9', '2016-1-10', '2016-1-11', '2016-1-12', '2016-1-13', '2016-1-14', '2016-1-15', '2016-1-16', '2016-1-17', '2016-1-18', '2016-1-19', '2016-1-20', '2016-1-21', '2016-1-22', '2016-1-23', '2016-1-24', '2016-1-25', '2016-1-26', '2016-1-27', '2016-1-28', '2016-1-29', '2016-1-30', '2016-1-31', '2016-2-1', '2016-2-2', '2016-2-3', '2016-2-4', '2016-2-5', '2016-2-6', '2016-2-7', '2016-2-8', '2016-2-9', '2016-2-10', '2016-2-11', '2016-2-12', '2016-2-15', '2016-2-16', '2016-2-17', '2016-2-18', '2016-2-19', '2016-2-20', '2016-2-21', '2016-2-22', '2016-2-23', '2016-2-24', '2016-2-25', '2016-2-26', '2016-2-27', '2016-2-28', '2016-3-1', '2016-3-2', '2016-3-3', '2016-3-4', '2016-3-5', '2016-3-6', '2016-3-7', '2016-3-8', '2016-3-9', '2016-3-10', '2016-3-11', '2016-3-12', '2016-3-13', '2016-3-14', '2016-3-15', '2016-3-16', '2016-3-17', '2016-3-18', '2016-3-19', '2016-3-20', '2016-3-21', '2016-3-22',

In [35]:
# Short alias for Plotly's offline, in-notebook plotting function.
pyplt = py.offline.iplot

# Columns to draw against the date axis; each becomes one line in the figure.
name = ['actual', 'temp_1', 'temp_2', 'friend']

# Build the four traces in one pass instead of four hand-written calls.
# The individual trace names are still bound for any later use.
trace1, trace2, trace3, trace4 = (
    go.Scatter(x=dates, y=features[column], name=column) for column in name
)

layout = {
    'title': '气温变化',
    'xaxis': {'title': 'Date'},  # x-axis title
    'yaxis': {'title': 'Temp'},  # y-axis title
}
data = [trace1, trace2, trace3, trace4]
fig = {'data': data, 'layout': layout}
pyplt(fig)

In [36]:
# One-hot encode the categorical `week` column.
# The dtype is pinned explicitly: pandas 2.0 changed the default indicator
# dtype from uint8 to bool, which would turn the 0/1 columns into True/False
# and make the later NumPy conversion of the mixed frame fall back to object
# dtype. uint8 is the historical default, so older pandas behaves the same.
features = pd.get_dummies(features, columns=['week'], dtype=np.uint8)
features.head()

Unnamed: 0,year,month,day,temp_2,temp_1,average,actual,friend,week_Fri,week_Mon,week_Sat,week_Sun,week_Thurs,week_Tues,week_Wed
0,2016,1,1,45,45,45.6,45,29,1,0,0,0,0,0,0
1,2016,1,2,44,45,45.7,44,61,0,0,1,0,0,0,0
2,2016,1,3,45,44,45.8,41,56,0,0,0,1,0,0,0
3,2016,1,4,44,41,45.9,40,53,0,1,0,0,0,0,0
4,2016,1,5,41,40,46.0,44,41,0,0,0,0,0,1,0


In [37]:
# Target vector: the `actual` column, copied out as a NumPy array.
labels = np.array(features['actual'])

# Feature frame: everything except the target column.
features = features.drop(columns='actual')

# Keep the column names aside for later use; the NumPy conversion below
# discards them.
feature_list = features.columns.tolist()

# Convert the remaining frame to a plain NumPy array.
features = np.array(features)

In [39]:
# Standardise every feature column (fit and transform on the same array).
scaler = preprocessing.StandardScaler()
input_features = scaler.fit_transform(features)
input_features

array([[ 0.        , -1.5678393 , -1.65682171, ..., -0.40482045,
        -0.41913682, -0.40482045],
       [ 0.        , -1.5678393 , -1.54267126, ..., -0.40482045,
        -0.41913682, -0.40482045],
       [ 0.        , -1.5678393 , -1.4285208 , ..., -0.40482045,
        -0.41913682, -0.40482045],
       ...,
       [ 0.        ,  1.5810006 ,  1.53939107, ...,  2.47023092,
        -0.41913682, -0.40482045],
       [ 0.        ,  1.5810006 ,  1.65354153, ..., -0.40482045,
        -0.41913682, -0.40482045],
       [ 0.        ,  1.5810006 ,  1.76769198, ..., -0.40482045,
        -0.41913682, -0.40482045]])

In [40]:
# Network and training hyper-parameters.
input_dim = input_features.shape[1]  # one input unit per feature column
hidden_dim = 128
output_dim = 1
batch_size = 16

# Single hidden layer with a sigmoid activation, followed by a linear output.
layers = [
    torch.nn.Linear(input_dim, hidden_dim),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_dim, output_dim),
]
my_nn = torch.nn.Sequential(*layers)

# Mean-squared-error loss, averaged over the batch ('mean' is also the default).
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

In [75]:
# Train the network with mini-batch gradient descent.
#
# The data is converted to tensors once, outside the loops: it never changes
# between epochs. Inputs and targets are constants, so they are created
# without requires_grad; only the model parameters need gradients.
train_x = torch.tensor(input_features, dtype=torch.float)
train_y = torch.tensor(labels.reshape(-1, 1), dtype=torch.float)

losses = []
for i in range(1000):
    batch_loss = []
    # Batches are taken in row order (no shuffling). Slicing clamps at the end
    # of the tensor, so the final, shorter batch needs no special handling.
    for start in range(0, len(train_x), batch_size):
        xx = train_x[start:start + batch_size]
        yy = train_y[start:start + batch_size]
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        optimizer.zero_grad()
        # A fresh graph is built for every batch, so it does not need to be
        # retained after the backward pass.
        loss.backward()
        optimizer.step()
        # .item() extracts the scalar loss as a plain Python float.
        batch_loss.append(loss.item())

    # Record and print the mean batch loss every 100 epochs.
    if i % 100 == 0:
        epoch_loss = np.mean(batch_loss)
        losses.append(epoch_loss)
        print(i, epoch_loss)

0 28.63597
100 18.717802
200 17.247545
300 16.015348
400 14.826138
500 13.67918
600 12.54874
700 11.442031
800 10.355056
900 9.272189


In [86]:
# Run the trained network on the full (standardised) feature matrix.
x = torch.tensor(input_features, dtype=torch.float)

# Inference only: disable autograd so no graph is built, which also lets the
# result be converted to NumPy directly instead of going through `.data`.
with torch.no_grad():
    predict = my_nn(x).numpy().reshape(-1)

# Predicted values versus the true labels, both plotted against the dates.
trace1 = go.Scatter(
    x=dates,
    y=predict,
    name='预测值'
)
trace2 = go.Scatter(
    x=dates,
    y=labels,
    name='真实值'
)

layout = dict(
    title='预测值和真实值',
    xaxis=dict(title='Date'),  # x-axis title
    yaxis=dict(title='Temp')  # y-axis title
)
data = [trace1, trace2]
fig = dict(data=data, layout=layout)
pyplt(fig)

In [87]:
# Print the flattened array of model predictions.
# NOTE(review): this prints the whole array; consider slicing (e.g. predict[:10]).
print(predict)

[47.289417 44.482864 43.605755 42.137585 45.674576 49.124    46.541767
 48.123863 48.632187 48.232517 49.861004 46.970264 53.296684 51.902046
 51.06278  52.79298  51.93604  50.705082 53.66526  51.02204  50.48243
 53.963196 51.89928  47.74838  52.841267 55.104916 57.626305 56.588326
 54.193398 49.613003 48.026802 44.28821  50.924366 51.799328 47.897266
 50.11854  51.876423 52.662437 49.72248  57.908146 56.233147 55.105816
 54.63738  54.24248  55.823017 56.358868 55.35517  53.958138 52.42529
 50.675304 52.57439  56.593338 60.23627  58.84674  59.842354 57.423958
 53.22891  57.522728 54.385788 57.473118 58.13455  60.487606 62.187214
 55.490166 54.39579  56.91364  55.616573 56.298004 55.8821   53.72921
 52.074787 50.890713 53.563934 56.15302  60.466526 61.686222 56.133186
 55.934906 57.132004 55.411293 55.031372 55.92523  55.381615 51.17492
 56.514984 64.265945 64.692245 70.55899  71.90258  63.217995 66.25675
 66.84089  58.407642 62.944088 73.95231  75.976204 68.274025 63.369152
 60.73924  

In [79]:
# Display the list of parsed datetime objects as the cell output.
# NOTE(review): this shows every entry; consider slicing (e.g. dates[:5]).
dates

[datetime.datetime(2016, 1, 1, 0, 0),
 datetime.datetime(2016, 1, 2, 0, 0),
 datetime.datetime(2016, 1, 3, 0, 0),
 datetime.datetime(2016, 1, 4, 0, 0),
 datetime.datetime(2016, 1, 5, 0, 0),
 datetime.datetime(2016, 1, 6, 0, 0),
 datetime.datetime(2016, 1, 7, 0, 0),
 datetime.datetime(2016, 1, 8, 0, 0),
 datetime.datetime(2016, 1, 9, 0, 0),
 datetime.datetime(2016, 1, 10, 0, 0),
 datetime.datetime(2016, 1, 11, 0, 0),
 datetime.datetime(2016, 1, 12, 0, 0),
 datetime.datetime(2016, 1, 13, 0, 0),
 datetime.datetime(2016, 1, 14, 0, 0),
 datetime.datetime(2016, 1, 15, 0, 0),
 datetime.datetime(2016, 1, 16, 0, 0),
 datetime.datetime(2016, 1, 17, 0, 0),
 datetime.datetime(2016, 1, 18, 0, 0),
 datetime.datetime(2016, 1, 19, 0, 0),
 datetime.datetime(2016, 1, 20, 0, 0),
 datetime.datetime(2016, 1, 21, 0, 0),
 datetime.datetime(2016, 1, 22, 0, 0),
 datetime.datetime(2016, 1, 23, 0, 0),
 datetime.datetime(2016, 1, 24, 0, 0),
 datetime.datetime(2016, 1, 25, 0, 0),
 datetime.datetime(2016, 1, 26, 0,

In [83]:
# Print the label array (the values taken from the `actual` column).
# NOTE(review): this prints the whole array; consider slicing (e.g. labels[:10]).
print(labels)

[45 44 41 40 44 51 45 48 50 52 45 49 55 49 48 54 50 54 48 52 52 57 48 51
 54 56 57 56 52 48 47 46 51 49 49 53 49 51 57 62 56 55 58 55 56 57 53 51
 53 51 51 60 59 61 60 57 53 58 55 59 57 64 60 53 54 55 56 55 52 54 49 51
 53 58 63 61 55 56 57 53 54 57 59 51 56 64 68 73 71 63 69 60 57 68 77 76
 66 59 58 60 59 59 60 68 77 89 81 81 73 64 65 55 59 60 61 64 61 68 77 87
 74 60 68 77 82 63 67 75 81 77 82 65 57 60 71 64 63 66 59 66 65 66 66 65
 64 64 64 71 79 75 71 80 81 92 86 85 67 65 67 65 70 66 60 67 71 67 65 70
 76 73 75 68 69 71 78 85 79 74 73 76 76 71 68 69 76 68 74 71 74 74 77 75
 77 76 72 80 73 78 82 81 71 75 80 85 79 83 85 88 76 73 77 73 75 80 79 72
 72 73 72 76 80 87 90 83 84 81 79 75 70 67 68 68 68 67 72 74 77 70 74 75
 79 71 75 68 69 71 67 68 67 64 67 76 77 69 68 66 67 63 65 61 63 66 63 64
 68 57 60 62 66 60 60 62 60 60 61 58 62 59 62 62 61 65 58 60 65 68 59 57
 57 65 65 58 61 63 71 65 64 63 59 55 57 55 50 52 55 57 55 54 54 49 52 52
 53 48 52 52 52 46 50 49 46 40 42 40 41 36 44 44 43