In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from utils import od_dataloader,masked_mse, ODLoss
from config import get_configs
import numpy as np
import pandas as pd
from lib import train, test
from models import ContrastiveLoss

In [3]:
args = get_configs()

In [4]:
args.model_name = 'lstm'
args.epoch = 300
# args.batch_size = 64

In [5]:
se = torch.Tensor(np.load(args.se_path)).to(args.device)
pe = torch.Tensor(pd.read_csv(args.pe_path,index_col=0).values).to(args.device)
train_dataloader,val_dataloader, test_dataloader,z_adj = od_dataloader(args)

train number:832
validation number:128
test number:128


In [35]:
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, input_size)  # 输出与输入形状相同

    def forward(self, x, te, se, pe, adj,z_adj):
        B,T,N,N = x.size()
        x = x.reshape(B,T,N*N)
        out, _ = self.lstm(x)
        out = self.fc(out)
        out = out.reshape(B,T,N,N)
        return out


In [39]:
model = LSTMModel(66*66,128,1,66*66)

In [40]:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = ODLoss
contrastive_loss = ContrastiveLoss(margin=2)
scheduler = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

In [41]:
train(model,train_dataloader,se,pe,z_adj,val_dataloader,optimizer,criterion,scheduler,args)

Epoch: 00, Loss: 3.9830
test loss: 2.7600
Epoch: 01, Loss: 2.8767
test loss: 2.4848
Epoch: 02, Loss: 2.7977
test loss: 2.4734
Epoch: 03, Loss: 2.7236
test loss: 2.3846
Epoch: 04, Loss: 2.6211
test loss: 2.3836
Epoch: 05, Loss: 2.5735
test loss: 2.2901
Epoch: 06, Loss: 2.4281
test loss: 2.1757
Epoch: 07, Loss: 2.4633
Epoch: 08, Loss: 2.4246
Epoch: 09, Loss: 2.3764
test loss: 2.1668
Epoch: 10, Loss: 2.4028
test loss: 2.1282
Epoch: 11, Loss: 2.3517
Epoch: 12, Loss: 2.3392
Epoch: 13, Loss: 2.2952
test loss: 2.0720
Epoch: 14, Loss: 2.2650
test loss: 2.0561
Epoch: 15, Loss: 2.1876
test loss: 2.0141
Epoch: 16, Loss: 2.1239
test loss: 1.9771
Epoch: 17, Loss: 2.0216
test loss: 1.9205
Epoch: 18, Loss: 1.9580
test loss: 1.8951
Epoch: 19, Loss: 1.9975
test loss: 1.8835
Epoch: 20, Loss: 1.9410
test loss: 1.8692
Epoch: 21, Loss: 1.9445
test loss: 1.8682
Epoch: 22, Loss: 1.9528
Epoch: 23, Loss: 1.8757
test loss: 1.8470
Epoch: 24, Loss: 1.8696
test loss: 1.8159
Epoch: 25, Loss: 1.8285
test loss: 1.808

In [42]:
mse,rmse,mae,mape,p_total,y_total = test(model,test_dataloader,se,pe,z_adj,criterion,args,True)

In [43]:
mse

1.622209906578064

In [44]:
rmse

1.2736600637435913

In [45]:
mae

0.5556039214134216

In [46]:
mape

1.6030184030532837

In [14]:
p_copy = p_total.reshape(48*24,66,66)
p_copy = p_copy.permute(1,2,0)
p_1 = p_copy[18][14]

RuntimeError: shape '[1152, 66, 66]' is invalid for input of size 13381632

In [None]:
y_copy = y_total.reshape(48*24,66,66)
y_copy = y_copy.permute(1,2,0)
y_1 = y_copy[18][14]

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(80, 5))
plt.plot(range(0,100), p_1[:100].cpu(), color='green', label='pred')
plt.plot(range(0,100), y_1[:100].cpu(), color='red', label='label')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
p_0 = torch.zeros(1,24,66,66).to(args.device)
p_total = torch.cat((p_total,p_0),dim=0)
p = p_total.reshape(49*24,66,66)
# p  = p_total.permute(1,0,2,3)
p_time = torch.mean(p, dim=1)
# p_time = torch.mean(p_time, dim=1)
p_time = torch.mean(p_time, dim=1)
p_time = p_time.reshape(7,7,24)
p_time = p_time[:,[3,4,5,6,0,1,2],:]
p_time = torch.mean(p_time, dim=0)
# p_time = p_time.reshape(168)

In [None]:
y_0 = torch.zeros(1,24,66,66).to(args.device)
y_total = torch.cat((y_total,y_0),dim=0)
y = y_total.reshape(49*24,66,66)
y_time =  torch.mean(y, dim=1)
# y_time =  torch.mean(y_time, dim=1)
y_time =  torch.mean(y_time, dim=1)
y_time = y_time.reshape(7,7,24)
y_time = y_time[:,[3,4,5,6,0,1,2],:]
y_time = torch.mean(y_time, dim=0)
# y_time = y_time.reshape(168)

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[1].cpu(), label='Mon_label')
plt.plot(range(6,30), p_time[1].cpu(), label='Mon_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[2].cpu(), label='Tue_label')
plt.plot(range(6,30), p_time[2].cpu(), label='Tue_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[3].cpu(), label='Wed_label')
plt.plot(range(6,30), p_time[3].cpu(), label='Wed_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[4].cpu(), label='Thu_label')
plt.plot(range(6,30), p_time[4].cpu(), label='Thu_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[5].cpu(), label='Fri_label')
plt.plot(range(6,30), p_time[5].cpu(), label='Fri_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[6].cpu(), label='Sat_label')
plt.plot(range(6,30), p_time[6].cpu(), label='Sat_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(6,30), y_time[0].cpu(), label='Sun_label')
plt.plot(range(6,30), p_time[0].cpu(), label='Sun_pred')
plt.legend()

plt.xlabel('time')
plt.ylabel('value')
plt.show()

In [None]:
y_weekday = y_total[:35]
y_weekday = y_weekday.reshape(5,7,24,66,66)
y_weekday = y_weekday[:,[3,4,5,6,0,1,2],:,:,:]
y_weekend = y_weekday[:,4:,:,:,:]
y_weekday = y_weekday[:,:4,:,:,:]
p_weekday = p_total[:35]
p_weekday = p_weekday.reshape(5,7,24,66,66)
p_weekday = p_weekday[:,[3,4,5,6,0,1,2],:,:,:]
p_weekend = p_weekday[:,4:,:,:,:]
p_weekday = p_weekday[:,:4,:,:,:]


In [None]:
weekday_mse = masked_mse(y_weekday,p_weekday)
weekday_mse

In [None]:
weekend_mse = masked_mse(y_weekend,p_weekend)
weekend_mse