In [18]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

In [2]:
# Define the network: a stacked LSTM followed by a linear head on the last time step.
class RNN(nn.Module):
    """Stacked LSTM that maps an input sequence to one output vector.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: size of the output produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        # Zero-argument super() keeps working even if a later cell rebinds
        # the global name ``RNN`` to a different class.
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Run the LSTM over ``x`` and project the last time step.

        Args:
            x: tensor of shape (batch, seq_len, input_size) because the LSTM
                is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Zero initial hidden and cell states, created from ``x`` so they share
        # its device and dtype. (The previous version moved them to a global
        # ``device`` that none of these cells define.)
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # out: (batch, seq_len, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Keep only the last time step; for this unidirectional LSTM that is
        # the top layer's final hidden state h_n[-1].
        return self.fc(out[:, -1, :])

In [3]:
# Hyperparameters.
# NOTE(review): sequence_length / input_size = 28 and num_classes = 10 look like
# tutorial defaults; the feature list used below has 15 columns and the model
# is built with a single output — confirm before relying on these values.
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

# Seed the RNGs so weight initialisation is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [10]:
# Load the raw furnace data. Set the FURNACE_DATA environment variable to point
# at another copy of the CSV instead of editing this hard-coded local path.
import os

DATA_PATH = os.environ.get('FURNACE_DATA', r'e:\learning\code\data\furnace_date.csv')
df = pd.read_csv(DATA_PATH)

In [11]:
# Drop rows with any missing value, index by the timestamp column and cast
# every remaining column to float.
df = (
    df.dropna(axis=0)
      .set_index('datetime')
      .astype('float')
)
df

Unnamed: 0_level_0,pi459b,fiq457b,fiq458b,fiq459b,fiq456b,ficq453b,pi461b,ti403b,fiq460b,ficq452b,...,number,label,ficq452b_set,ficq451b_set,or1_cur_b3_set,or2_cur_b3_set,or3_cur_b3_set,mr1_cur_b3_set,mr2_cur_b3_set,ir_cur_b3_set
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021/12/13 0:01,0.384,9.927,6.969,98.062,148.120,-1.314,0.619,39.979,62.181,5453.152,...,0.0,4002.0,5468.391,1876.136,1287.896,1257.864,1287.896,1257.880,1205.852,1119.029
2021/12/13 0:02,0.384,9.926,6.973,98.059,148.114,-1.314,0.620,39.949,58.438,5496.660,...,0.0,4002.0,5468.232,1876.303,1288.380,1258.331,1288.380,1258.347,1206.285,1119.379
2021/12/13 0:03,0.383,9.926,6.973,98.055,148.130,-1.310,0.619,39.931,54.896,5458.606,...,0.0,4002.0,5468.074,1876.469,1288.863,1258.798,1288.863,1258.813,1206.718,1119.729
2021/12/13 0:04,0.384,9.926,6.969,97.832,148.117,-1.310,0.618,39.912,56.428,5509.111,...,0.0,4002.0,5467.925,1876.636,1289.319,1259.231,1289.319,1259.246,1207.121,1120.079
2021/12/13 0:05,0.383,9.926,6.973,98.393,148.120,-1.314,0.619,39.924,59.238,5481.219,...,0.0,4002.0,5467.775,1876.803,1289.786,1259.664,1289.786,1259.679,1207.505,1120.429
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022/2/15 19:15,0.391,9.933,6.637,95.291,155.760,-1.282,0.624,40.764,58.822,2933.157,...,13.0,4002.0,2922.745,1198.167,1588.414,1475.414,1588.414,1475.414,1395.614,1211.014
2022/2/15 19:16,0.391,9.933,6.641,95.521,155.763,-1.285,0.624,40.737,54.402,2941.973,...,13.0,4002.0,2922.735,1198.167,1588.416,1475.416,1588.416,1475.416,1395.616,1211.016
2022/2/15 19:17,0.391,9.933,6.641,95.189,155.993,-1.288,0.623,40.723,62.171,2917.796,...,13.0,4002.0,2922.725,1198.167,1588.417,1475.417,1588.417,1475.417,1395.617,1211.017
2022/2/15 19:18,0.390,9.933,6.636,95.523,155.769,-1.282,0.623,40.710,60.632,2907.065,...,13.0,4002.0,2922.715,1198.167,1588.419,1475.419,1588.419,1475.419,1395.619,1211.019


In [13]:
# Columns to predict (the `*_set` columns of the raw data).
target = [
    'or1_cur_b3_set',
    'or2_cur_b3_set',
    'mr2_cur_b3_set',
    'ir_cur_b3_set',
]

In [14]:
# Furnace numbers kept for training; furnace 12 is held out as test data.
TRAIN_FURNACES = (0, 6, 11, 13)
# How many rows the target columns are shifted back by (rows in the data shown
# are one minute apart — TODO confirm this holds for the whole file).
TARGET_SHIFT = 30


def fil_fun(x, furnaces=TRAIN_FURNACES, shift=TARGET_SHIFT):
    """Shift the ``target`` columns of one furnace group ``shift`` rows earlier.

    Meant for ``df.groupby('number').apply(fil_fun)``. Each row's target values
    are replaced by the values ``shift`` rows later, so features are paired
    with a later target. Groups whose furnace number is not in ``furnaces``
    return ``None``, which ``apply`` drops from the combined result.

    Args:
        x: rows of a single furnace; every row shares one ``number`` value.
        furnaces: furnace numbers to keep.
        shift: number of rows to shift the target columns by.

    Returns:
        A shifted copy of ``x``, or ``None`` for excluded furnaces.
    """
    furnace = x['number'].iloc[0]
    if furnace not in furnaces:
        return None
    # Work on a copy so the group slice handed over by groupby is not mutated.
    shifted = x.copy()
    shifted[target] = shifted[target].shift(-shift)
    return shifted


df2 = (
    df.groupby('number')
      .apply(fil_fun)
      .drop(columns=['number'])
      .reset_index()
      # Drop rows that are missing values, notably the last `shift` rows of
      # each furnace whose targets the shift left as NaN.
      .dropna(axis=0)
)
df2

Unnamed: 0,number,datetime,pi459b,fiq457b,fiq458b,fiq459b,fiq456b,ficq453b,pi461b,ti403b,...,dcs,label,ficq452b_set,ficq451b_set,or1_cur_b3_set,or2_cur_b3_set,or3_cur_b3_set,mr1_cur_b3_set,mr2_cur_b3_set,ir_cur_b3_set
0,0.0,2021/12/13 0:01,0.384,9.927,6.969,98.062,148.120,-1.314,0.619,39.979,...,0.012282,4002.0,5468.391,1876.136,1301.919,1270.931,1287.896,1257.880,1217.471,1129.529
1,0.0,2021/12/13 0:02,0.384,9.926,6.973,98.059,148.114,-1.314,0.620,39.949,...,0.012282,4002.0,5468.232,1876.303,1302.385,1271.365,1288.380,1258.347,1217.855,1129.879
2,0.0,2021/12/13 0:03,0.383,9.926,6.973,98.055,148.130,-1.310,0.619,39.931,...,0.010845,4002.0,5468.074,1876.469,1302.852,1271.798,1288.863,1258.813,1218.238,1130.229
3,0.0,2021/12/13 0:04,0.384,9.926,6.969,97.832,148.117,-1.310,0.618,39.912,...,0.012282,4002.0,5467.925,1876.636,1303.327,1272.238,1289.319,1259.246,1218.621,1130.579
4,0.0,2021/12/13 0:05,0.383,9.926,6.973,98.393,148.120,-1.314,0.619,39.924,...,0.014582,4002.0,5467.775,1876.803,1303.786,1272.664,1289.786,1259.679,1219.005,1130.929
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20740,13.0,2022/2/15 18:45,0.390,9.933,6.659,95.439,155.769,-1.282,0.621,40.778,...,1.225000,4002.0,2923.044,1198.167,1588.414,1475.414,1588.364,1475.364,1395.614,1211.014
20741,13.0,2022/2/15 18:46,0.389,9.933,6.664,95.524,155.769,-1.278,0.622,40.812,...,1.228000,4002.0,2923.034,1198.167,1588.416,1475.416,1588.366,1475.366,1395.616,1211.016
20742,13.0,2022/2/15 18:47,0.389,9.933,6.668,95.788,155.772,-1.278,0.623,40.808,...,1.231000,4002.0,2923.024,1198.167,1588.417,1475.417,1588.367,1475.367,1395.617,1211.017
20743,13.0,2022/2/15 18:48,0.390,9.933,6.664,95.519,155.760,-1.278,0.621,40.847,...,1.235000,4002.0,2923.014,1198.167,1588.419,1475.419,1588.369,1475.369,1395.619,1211.019


In [15]:
# Make the timestamp the index again (reset_index above turned it into a column).
df2 = df2.set_index('datetime')
df2

Unnamed: 0_level_0,number,pi459b,fiq457b,fiq458b,fiq459b,fiq456b,ficq453b,pi461b,ti403b,fiq460b,...,dcs,label,ficq452b_set,ficq451b_set,or1_cur_b3_set,or2_cur_b3_set,or3_cur_b3_set,mr1_cur_b3_set,mr2_cur_b3_set,ir_cur_b3_set
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021/12/13 0:01,0.0,0.384,9.927,6.969,98.062,148.120,-1.314,0.619,39.979,62.181,...,0.012282,4002.0,5468.391,1876.136,1301.919,1270.931,1287.896,1257.880,1217.471,1129.529
2021/12/13 0:02,0.0,0.384,9.926,6.973,98.059,148.114,-1.314,0.620,39.949,58.438,...,0.012282,4002.0,5468.232,1876.303,1302.385,1271.365,1288.380,1258.347,1217.855,1129.879
2021/12/13 0:03,0.0,0.383,9.926,6.973,98.055,148.130,-1.310,0.619,39.931,54.896,...,0.010845,4002.0,5468.074,1876.469,1302.852,1271.798,1288.863,1258.813,1218.238,1130.229
2021/12/13 0:04,0.0,0.384,9.926,6.969,97.832,148.117,-1.310,0.618,39.912,56.428,...,0.012282,4002.0,5467.925,1876.636,1303.327,1272.238,1289.319,1259.246,1218.621,1130.579
2021/12/13 0:05,0.0,0.383,9.926,6.973,98.393,148.120,-1.314,0.619,39.924,59.238,...,0.014582,4002.0,5467.775,1876.803,1303.786,1272.664,1289.786,1259.679,1219.005,1130.929
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022/2/15 18:45,13.0,0.390,9.933,6.659,95.439,155.769,-1.282,0.621,40.778,58.615,...,1.225000,4002.0,2923.044,1198.167,1588.414,1475.414,1588.364,1475.364,1395.614,1211.014
2022/2/15 18:46,13.0,0.389,9.933,6.664,95.524,155.769,-1.278,0.622,40.812,56.069,...,1.228000,4002.0,2923.034,1198.167,1588.416,1475.416,1588.366,1475.366,1395.616,1211.016
2022/2/15 18:47,13.0,0.389,9.933,6.668,95.788,155.772,-1.278,0.623,40.808,63.143,...,1.231000,4002.0,2923.024,1198.167,1588.417,1475.417,1588.367,1475.367,1395.617,1211.017
2022/2/15 18:48,13.0,0.390,9.933,6.664,95.519,155.760,-1.278,0.621,40.847,62.888,...,1.235000,4002.0,2923.014,1198.167,1588.419,1475.419,1588.369,1475.369,1395.619,1211.019


In [16]:
# Model input features and the columns to predict.
elec_feature = [
    'ficq452b', 'ficq451b', 'pi454b', 'tic453b', 'ti452b',
    'pi452b', 'ti451b', 'tic457b', 'fiq455b', 'fiq454b',
    'dcs', 'or1_cur_b3', 'or2_cur_b3', 'mr2_cur_b3', 'ir_cur_b3',
]
target = [
    'or1_cur_b3_set',
    'or2_cur_b3_set',
    'mr2_cur_b3_set',
    'ir_cur_b3_set',
]
df3 = df2[elec_feature]
df3

Unnamed: 0_level_0,ficq452b,ficq451b,pi454b,tic453b,ti452b,pi452b,ti451b,tic457b,fiq455b,fiq454b,dcs,or1_cur_b3,or2_cur_b3,mr2_cur_b3,ir_cur_b3
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
2021/12/13 0:01,5453.152,1818.666,0.532,156.046,535.493,0.589,162.793,154.949,56.499,172.702,0.012282,1286.0,1256.0,1205.0,1119.0
2021/12/13 0:02,5496.660,1880.827,0.532,156.129,537.400,0.589,162.813,155.052,56.305,172.725,0.012282,1289.0,1257.0,1205.0,1117.0
2021/12/13 0:03,5458.606,1807.220,0.533,156.199,537.169,0.589,162.905,155.075,56.539,172.724,0.010845,1288.0,1258.0,1206.0,1118.0
2021/12/13 0:04,5509.111,1822.385,0.532,156.538,536.928,0.588,163.004,155.097,56.191,172.153,0.012282,1289.0,1259.0,1207.0,1119.0
2021/12/13 0:05,5481.219,1851.505,0.531,156.690,537.805,0.590,163.093,155.120,56.158,171.571,0.014582,1288.0,1258.0,1208.0,1122.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022/2/15 18:45,2931.621,1216.146,0.527,154.057,529.197,0.552,151.973,151.608,65.000,194.461,1.225000,1581.0,1475.0,1390.0,1213.0
2022/2/15 18:46,2911.799,1182.298,0.527,154.053,528.884,0.552,152.210,151.839,65.000,195.025,1.228000,1580.0,1469.0,1400.0,1212.0
2022/2/15 18:47,2932.131,1214.272,0.526,154.102,530.325,0.551,152.509,151.924,65.000,195.582,1.231000,1592.0,1466.0,1391.0,1212.0
2022/2/15 18:48,2928.280,1219.490,0.526,154.077,530.166,0.551,152.509,151.948,65.000,196.133,1.235000,1586.0,1476.0,1388.0,1203.0


In [20]:
# Min-max scale every feature column. The scaled frame gets a fresh 0..n-1
# RangeIndex; the datetime index of df3 is not carried over.
scaler = MinMaxScaler()
df4 = pd.DataFrame(scaler.fit_transform(df3), columns=df3.columns)
df_train = df4.loc[:, elec_feature]

In [21]:
# Preview the scaled training features.
df_train

Unnamed: 0,ficq452b,ficq451b,pi454b,tic453b,ti452b,pi452b,ti451b,tic457b,fiq455b,fiq454b,dcs,or1_cur_b3,or2_cur_b3,mr2_cur_b3,ir_cur_b3
0,0.928344,0.702017,0.421053,0.860627,0.859771,0.788732,0.704437,0.774271,0.197103,0.268684,0.023038,0.804756,0.776245,0.831131,0.878486
1,0.936069,0.725293,0.421053,0.861445,0.863668,0.788732,0.704649,0.780946,0.190266,0.269252,0.023038,0.806633,0.777011,0.831131,0.876494
2,0.929312,0.697731,0.473684,0.862136,0.863196,0.788732,0.705623,0.782437,0.198513,0.269227,0.021893,0.806008,0.777778,0.831865,0.877490
3,0.938280,0.703409,0.421053,0.865479,0.862703,0.774648,0.706671,0.783863,0.186249,0.255143,0.023038,0.806633,0.778544,0.832599,0.878486
4,0.933328,0.714314,0.368421,0.866978,0.864495,0.802817,0.707614,0.785353,0.185086,0.240787,0.024871,0.806008,0.777778,0.833333,0.881474
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20650,0.480625,0.476400,0.157895,0.841012,0.846906,0.267606,0.589876,0.557745,0.496687,0.805387,0.989638,0.989362,0.944061,0.966960,0.972112
20651,0.477106,0.463725,0.157895,0.840973,0.846266,0.267606,0.592385,0.572715,0.496687,0.819299,0.992029,0.988736,0.939464,0.974302,0.971116
20652,0.480716,0.475698,0.105263,0.841456,0.849211,0.253521,0.595551,0.578224,0.496687,0.833037,0.994421,0.996245,0.937165,0.967695,0.971116
20653,0.480032,0.477652,0.105263,0.841209,0.848886,0.253521,0.595551,0.579780,0.496687,0.846628,0.997609,0.992491,0.944828,0.965492,0.962151


In [23]:
# Move the timestamp back into a column (rows become numbered 0..n-1, like
# df_train) and take 'or1_cur_b3_set' as the prediction target.
df2 = df2.reset_index()
df_target = df2['or1_cur_b3_set']
df_target

0        1301.919
1        1302.385
2        1302.852
3        1303.327
4        1303.786
           ...   
20650    1588.414
20651    1588.416
20652    1588.417
20653    1588.419
20654    1588.421
Name: or1_cur_b3_set, Length: 20655, dtype: float64

In [25]:
# One LSTM input per feature column, hidden size equal to the feature count,
# a single output, and the configured number of stacked layers
# (num_layers=1000 here previously built a 1000-layer LSTM).
model = RNN(df_train.shape[1], df_train.shape[1], num_layers, 1)

In [None]:
class RNN(nn.Module):
    """Plain ``nn.RNN`` that emits one prediction per time step.

    NOTE: running this cell rebinds the name ``RNN``, shadowing the LSTM-based
    ``RNN`` class defined earlier in the notebook (different ``forward``
    signature).

    Args:
        input_size: number of features per time step.
        hidden_size: size of the RNN hidden state.
        num_layers: number of stacked RNN layers.
        output_size: size of the prediction at each time step.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output_layer = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, h_state=None):
        """Run the RNN and project every time step through the output layer.

        Args:
            x: tensor of shape (batch, time_step, input_size).
            h_state: initial hidden state of shape (num_layers, batch,
                hidden_size), e.g. the state returned by a previous call;
                ``None`` lets ``nn.RNN`` start from zeros.

        Returns:
            ``(out, h_state)`` where ``out`` has shape
            (batch, time_step, output_size) and ``h_state`` is the final
            hidden state, shape (num_layers, batch, hidden_size).
        """
        rnn_out, h_state = self.rnn(x, h_state)
        # nn.Linear acts on the last dimension, so one call applies the same
        # output layer to every time step — no per-step Python loop needed.
        return self.output_layer(rnn_out), h_state

In [27]:
# Mean squared error (MSE) loss, minimised with Adam at the configured learning rate.
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)