In [1]:
from models import *
from utils import *
from test import *

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import itertools

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from functools import reduce

In [2]:
X, y = CU_BEMS()
X_columns_to_normalize = X.columns.difference(['date'])
X_scaler = MinMaxScaler()
X[X_columns_to_normalize] = X_scaler.fit_transform(X[X_columns_to_normalize])

X_train = X[X["date"] <= '2019-10-01']
y_train = y[y["date"] <= '2019-10-01']

X_test = X[X["date"] >= '2019-10-01']
y_test = y[y["date"] >= '2019-10-01']

In [3]:
seq_length = 50
predict_length = 1
batch_size = 100

Train_dataset = TimeSeriesDataset_sep(X_train, y_train, seq_length, predict_length = predict_length)
Train_dataloader = DataLoader(Train_dataset, batch_size=batch_size, shuffle=False)

Test_dataset = TimeSeriesDataset_sep(X_test, y_test, seq_length, predict_length = predict_length)
Test_dataloader = DataLoader(Test_dataset, batch_size=batch_size, shuffle=False)



In [4]:
num_epochs = 50
input_size = len(X.columns)
output_size = 1
hidden_size = 100
num_layers = 2
num_heads = 5
model = BiLSTMTransformer(input_size, hidden_size, num_layers, output_size, num_heads, predict_length).to(device)



In [5]:
len(X.columns)

16

In [8]:
teacher_forcing_ratio = 0.8
criterion = nn.MSELoss()
#criterion = nn.SmoothL1Loss()  # 用于回归任务
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
loss_per_epoch = []
val_mse_per_epoch = []
val_r2_per_epoch = []

for epoch in range(num_epochs):
    model.train()  # 确保模型在训练模式下
    for external, internal, batch_y in Train_dataloader:
        external, internal, batch_y = external.to(device), internal.to(device), batch_y.to(device)

        # 前向传播
        total_loss = 0
        y = internal[:, -1:, :]
        for step in range(predict_length):
            outputs = model(external, internal, y)
            next_pred = outputs[:, -1:, :]
            #outputs = model(batch_X, batch_y.view(batch_y.shape[0], batch_y.shape[2]))
            loss = criterion(next_pred, batch_y[:, step:step+1, :])
            total_loss += loss
            if np.random.rand() < teacher_forcing_ratio:
                next_input = batch_y[:, step:step + 1, :]
            else:
                next_input = next_pred
            
            y = torch.cat([y, next_input], dim=1)

        total_loss = total_loss / predict_length
        # 反向传播和优化
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    #scheduler.step()

    loss_per_epoch.append(loss.item())
    teacher_forcing_ratio -= 0.05

    # 评估验证集
    val_loss, (val_mse, val_mae, val_r2) = evaluate_Transformer(model, Test_dataloader, criterion, device, [mean_squared_error, mean_absolute_error, r2_score])
    val_mse_per_epoch.append(val_mse)
    val_r2_per_epoch.append(val_r2)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, Val MSE: {val_mse:.4f}, Val MAE: {val_mae:.4f}, Val R²: {val_r2:.4f}')


Epoch [1/50], Loss: 46655.9336, Val Loss: 28623.4837, Val MSE: 28629.4023, Val MAE: 152.0427, Val R²: -0.0932
Epoch [2/50], Loss: 46774.6523, Val Loss: 29418.5055, Val MSE: 29407.7188, Val MAE: 156.0596, Val R²: -0.1229
Epoch [3/50], Loss: 46637.2383, Val Loss: 28976.4618, Val MSE: 28974.6855, Val MAE: 153.8848, Val R²: -0.1064
Epoch [4/50], Loss: 46616.5664, Val Loss: 29624.9703, Val MSE: 29610.1777, Val MAE: 157.0310, Val R²: -0.1307
Epoch [5/50], Loss: 47030.9062, Val Loss: 29872.2673, Val MSE: 29852.8340, Val MAE: 158.1595, Val R²: -0.1399
Epoch [6/50], Loss: 46185.8086, Val Loss: 27444.4499, Val MSE: 27481.0098, Val MAE: 144.8307, Val R²: -0.0494
Epoch [7/50], Loss: 46788.0977, Val Loss: 29544.8845, Val MSE: 29531.6328, Val MAE: 156.6573, Val R²: -0.1277
Epoch [8/50], Loss: 47205.6328, Val Loss: 30229.7971, Val MSE: 30203.9102, Val MAE: 159.7300, Val R²: -0.1533
Epoch [9/50], Loss: 47049.4375, Val Loss: 30375.2880, Val MSE: 30346.8555, Val MAE: 160.3519, Val R²: -0.1588
Epoch [10/

KeyboardInterrupt: 