In [None]:
#!/usr/bin/env python3
# coding : utf-8
# author : 欧宁益
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
import datetime
# import os
# import is_holiday
import requests
import json
# import calendar 
import torch
import pickle  
from pmdarima.arima import auto_arima

# from typing import Any, Dict
import jieqi
import lunarcalendar as lc
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.tuner import Tuner
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer,Baseline,NHiTS,NBeats
from pytorch_forecasting.data import NaNLabelEncoder,GroupNormalizer
from pytorch_forecasting.metrics import SMAPE,MultivariateNormalDistributionLoss,MQF2DistributionLoss
from lightning.pytorch.loggers import TensorBoardLogger
# sns.set_theme(font='Microsoft YaHei')
# warnings.filterwarnings('ignore')
import chinese_calendar as cc
from chinese_calendar.solar_terms import (
    SOLAR_TERMS_C_NUMS,
    SOLAR_TERMS_DELTA,
    SOLAR_TERMS_MONTH,
    SolarTerms,
)
from lightning.pytorch.tuner import Tuner

In [None]:
# Load the raw sales, shop and goods tables from the project data directory.
base_path = 'c:/OuNingyi/21级工程管理欧宁益/论文基础数据/'
sale_df, shop_info, goods_info = (
    pd.read_csv(base_path + csv_name)
    for csv_name in ('sale_data.csv', 'shop_info.csv', 'goods_info.csv')
)

### 一、数据清洗和处理

In [None]:
# Solar-term lookup
def get_solar_terms(date):
    """
    Return the name of the solar term that falls on ``date``, if any.

    The day of each of the 24 solar terms is computed with the general
    "Shouxing" formula (https://www.jianshu.com/p/1f814c6bb475):

        day = [Y * D + C] - L

    where [] truncates to an integer, Y is the last two digits of the year,
    D = 0.2422, C is a per-term constant and L = Y / 4 -- except for Lesser
    Cold, Greater Cold, Beginning of Spring and Rain Water, which use
    L = (Y - 1) / 4.

    Adapted from chinese_calendar's ``get_solar_terms``: instead of listing
    the terms inside a date range, it answers for one single day.

    :param date: object exposing ``year``, ``month`` and ``day`` attributes.
    :return: ``solar_term.value[1]`` of the term whose computed day equals
             ``date.day``; ``None`` (implicitly) when no term of that month matches.
    :raises NotImplementedError: if the year is outside [1900, 2100].
    """
    year, month = date.year, date.month
    if year < 1900 or year > 2100:
        raise NotImplementedError("only year between [1900, 2100] supported")
    D = 0.2422
    Y = year % 100
    # Terms whose leap-year correction is based on the previous year.
    previous_year_terms = (
        SolarTerms.lesser_cold,
        SolarTerms.greater_cold,
        SolarTerms.the_beginning_of_spring,
        SolarTerms.rain_water,
    )
    # Only the terms belonging to this month need to be evaluated.
    for solar_term in SOLAR_TERMS_MONTH[month]:
        nums = SOLAR_TERMS_C_NUMS[solar_term]
        uses_previous_year = solar_term in previous_year_terms
        # nums[0] is the 20th-century constant, nums[1] the 21st-century one;
        # in 2000 the four early-year terms still use the 20th-century value.
        if year < 2000 or (year == 2000 and uses_previous_year):
            C = nums[0]
        else:
            C = nums[1]
        L = int((Y - 1) / 4) if uses_previous_year else int(Y / 4)
        day = int(Y * D + C) - L
        # Apply the per-(year, term) correction when one is registered.
        day += SOLAR_TERMS_DELTA.get((year, solar_term)) or 0
        _date = datetime.date(year, month, day)
        if date.day == _date.day:
            return solar_term.value[1]

In [None]:
# Parse the sale date column into pandas datetimes.
sale_df['ds'] = pd.to_datetime(sale_df['ds'])

In [None]:
# Build the list of every calendar day between the first and the last sale date
# (both inclusive). The entries are pandas Timestamps; no string formatting is applied.
first_sale_day, last_sale_day = sale_df.ds.min(), sale_df.ds.max()
if pd.isna(first_sale_day):
    # No sales at all: keep the list empty instead of letting date_range raise on NaT.
    date_list = []
else:
    date_list = list(pd.date_range(first_sale_day, last_sale_day, freq='D'))

In [None]:
time_df = pd.DataFrame(date_list, columns=['ds'])

# Plain Gregorian calendar parts, read through the .dt accessor.
for calendar_part in ('year', 'month', 'day', 'weekday'):
    time_df[calendar_part] = getattr(time_df['ds'].dt, calendar_part)
# weekday is 0 (Monday) .. 6 (Sunday); 5 and 6 are the weekend.
time_df['is_weekend'] = time_df['weekday'].ge(5).astype('int64')
# ISO week number within the year
time_df['week_of_year'] = time_df['ds'].dt.isocalendar().week.astype(int)
# Ordinal day within the year
time_df['day_of_year'] = time_df['ds'].dt.dayofyear
# Quarter
time_df['quarter'] = time_df['ds'].dt.quarter
# 1 when chinese_calendar reports the day as a holiday, else 0
time_df['is_holiday'] = time_df['ds'].map(lambda day: cc.is_holiday(day) * 1)
# Second element of chinese_calendar's holiday detail; missing values become 0
time_df['holiday_name'] = time_df['ds'].map(lambda day: cc.get_holiday_detail(day)[1]).fillna(0)


def _to_lunar(day):
    """Convert a Gregorian day to its lunarcalendar lunar date."""
    return lc.Converter.Solar2Lunar(lc.Solar(day.year, day.month, day.day))


# Lunar date object plus its individual components
time_df['lunar_date'] = time_df['ds'].map(_to_lunar)
for lunar_part in ('year', 'month', 'day'):
    time_df[f'lunar_{lunar_part}'] = time_df['lunar_date'].map(
        lambda lunar, lunar_part=lunar_part: getattr(lunar, lunar_part)
    )
time_df['lunar_is_leap'] = time_df['lunar_date'].map(lambda lunar: lunar.isleap * 1)  # leap-month flag
# Solar term falling on the day; 0 when there is none
time_df['solar_terms'] = time_df['ds'].map(get_solar_terms).fillna(0)

In [None]:
# Keep only the shop/goods pairs that have a positive sale quantity on more than
# 50% of the distinct sale dates.
# NOTE(review): the previous comment said 60%, but the threshold applied here is 0.5.
positive_sales = sale_df[sale_df['sale_qty'] > 0]
day_cnt = (
    positive_sales.groupby(['shop_code', 'goods_code'])['sale_qty']
    .count()
    .rename('天数')
    .reset_index()
)
min_selling_days = sale_df['ds'].nunique() * 0.5
sale_left_05 = day_cnt.loc[day_cnt['天数'] > min_selling_days, ['shop_code', 'goods_code']]
sale_df2 = sale_df.merge(sale_left_05, how='inner', on=['shop_code', 'goods_code'])

In [None]:
# Cartesian product of the retained shop/goods pairs with every calendar day,
# then attach the recorded sales (days without a record stay NaN).
shop_goods_pairs = sale_df2[['shop_code', 'goods_code']].drop_duplicates()
calendar_features = time_df.drop(columns='lunar_date')
pair_day_grid = shop_goods_pairs.merge(calendar_features, how='cross')
sale_df3 = pair_day_grid.merge(sale_df2, how='left', on=['shop_code', 'goods_code', 'ds'])

In [None]:
# Missing values and negative values of the sales columns are both replaced by 0.
# The result is assigned back to the frame: calling fillna(inplace=True) on
# `sale_df3.<col>` is chained assignment and does not update sale_df3 under
# pandas Copy-on-Write.
for sales_col in ('sale_amt', 'sale_qty'):
    sale_df3[sales_col] = sale_df3[sales_col].fillna(0).clip(lower=0)
# Time index: rank of the date within each shop/goods series (starts at 1).
sale_df3['time_idx'] = sale_df3.groupby(['shop_code','goods_code']).ds.rank('min')

In [None]:
# Value observed 365 rows earlier *within the same shop/goods series*.
# The shift must be done per group: a plain Series.shift(365) runs over the whole
# stacked frame, so the first 365 rows of each series would receive the tail of
# the previous series instead of NaN.
series_groups = sale_df3.groupby(['shop_code', 'goods_code'])
sale_df3['sale_qty_yoy'] = series_groups['sale_qty'].shift(365)
sale_df3['sale_amt_yoy'] = series_groups['sale_amt'].shift(365)

In [None]:
# Statistical features, built identically for quantity and amount.
series_day_count = sale_df3['time_idx'].max()  # largest time index, used as the day count
for measure in ('sale_qty', 'sale_amt'):
    # Log transform to compress the value range; 1e-8 keeps log(0) finite.
    sale_df3[f'log_{measure}'] = np.log(sale_df3[measure] + 1e-8)
    # Mean per shop/goods series.
    sale_df3[f'avg_{measure}_by_shop_sku'] = (
        sale_df3.groupby(['shop_code', 'goods_code'], observed=True)[measure].transform('mean')
    )
    # Total per goods (resp. per shop) divided by the day count.
    sale_df3[f'avg_{measure}_by_sku'] = (
        sale_df3.groupby('goods_code', observed=True)[measure].transform('sum') / series_day_count
    )
    sale_df3[f'avg_{measure}_by_shop'] = (
        sale_df3.groupby('shop_code', observed=True)[measure].transform('sum') / series_day_count
    )

# Missing year-over-year values fall back to the series mean. Assigned back
# explicitly: fillna(inplace=True) on `sale_df3[col]` is chained assignment and
# does not update sale_df3 under pandas Copy-on-Write.
for measure in ('sale_qty', 'sale_amt'):
    sale_df3[f'{measure}_yoy'] = sale_df3[f'{measure}_yoy'].fillna(sale_df3[f'avg_{measure}_by_shop_sku'])

In [None]:
# Dtype conversion.
# Identifier and flag columns are turned into string categories FIRST, from their
# original values, so that codes never pass through a narrower numeric type.
category_cols = ['shop_code', 'goods_code', 'is_weekend', 'is_holiday', 'holiday_name', 'lunar_is_leap', 'solar_terms']
for col in category_cols:
    sale_df3[col] = sale_df3[col].astype(str).astype('category')

# Remaining numeric columns are downcast to 32 bits.
for col in sale_df3.select_dtypes(include=['float64']).columns:
    sale_df3[col] = sale_df3[col].astype('float32')
int32_limits = np.iinfo(np.int32)
for col in sale_df3.select_dtypes(include=['int64']).columns:
    # Only downcast when every value fits; otherwise the column stays int64
    # rather than being silently wrapped.
    if sale_df3[col].between(int32_limits.min, int32_limits.max).all():
        sale_df3[col] = sale_df3[col].astype('int32')

# Any other text column becomes a string category as well.
for col in sale_df3.select_dtypes(include=['object']).columns:
    sale_df3[col] = sale_df3[col].astype(str).astype('category')

##### 店铺信息处理

In [None]:
def _add_label_flags(frame, source_col, flag_names):
    """
    Add one 0/1 column per distinct label of a comma-separated text column.

    The distinct labels are collected in order of first appearance and paired
    positionally with ``flag_names`` (``zip`` stops at the shorter of the two).
    A row is flagged only when the label is one of its comma-separated tokens;
    rows whose source value is missing get 0 for every flag.

    :param frame: DataFrame modified in place.
    :param source_col: name of the comma-separated text column.
    :param flag_names: names of the flag columns to create.
    :return: array of the distinct labels found in ``source_col``.
    """
    token_lists = frame[source_col].str.split(',')
    labels = token_lists.explode().dropna().unique()
    for label, flag_name in zip(labels, flag_names):
        # Exact token match: a substring test would also flag labels that are
        # merely contained in a longer label, and would raise on missing values.
        frame[flag_name] = token_lists.apply(
            lambda tokens, label=label: int(isinstance(tokens, list) and label in tokens)
        )
    return labels


# Flag column names for the distinct values of format_name.
# NOTE(review): names are matched to labels by order of first appearance in the
# data, so this list must follow that order -- confirm whenever the data changes.
format_is_fieldate = ['is_support_commercial_insurance','is_general_pharmacy','is_support_remote_medical_insurance_settlement','is_support_chronic_disease','is_contain_convenience_area',
                    'is_E_commerce_virtual_pharmacy','is_outpatient_coordination_pharmacy','is_DTP_pharmacy']
# Split format_name into flags; keep its distinct values.
format_col_name = _add_label_flags(shop_info, 'format_name', format_is_fieldate)

# Flag column names for the distinct values of busi_district_type_name (same positional pairing).
busi_district_type_is_fieldate = ['is_hosptial_pharmacy','is_community_pharmacy','is_business_district_adjacent_street_pharmacy','is_vegetable_market_pharmacy',
                                'is_transportation_junction_pharmacy','is_business_district_shop_in_shop','is_tourist_attraction_pharmacy','is_tertiary_school_pharmacy',
                                'is_business_district_pharmacy','is_primary_school_pharmacy','is_airport_railway_station_pharmacy','is_park_pharmacy']
# Split busi_district_type_name into flags; keep its distinct values.
busi_district_type_col_name = _add_label_flags(shop_info, 'busi_district_type_name', busi_district_type_is_fieldate)

# Drop the two columns that have been expanded into flags.
shop_prep = shop_info.drop(['format_name','busi_district_type_name'],axis=1)

In [None]:
# Flag the shops whose avg_sale_amt is missing, then impute the column median.
shop_prep['is_avg_sale_amt_na'] = shop_prep['avg_sale_amt'].isnull().astype('int64')
shop_prep['avg_sale_amt'] = shop_prep['avg_sale_amt'].fillna(shop_prep['avg_sale_amt'].median())

# Text-typed columns
shop_str_list = shop_prep.select_dtypes(include=['object']).columns.to_list()

# Continuous shop attributes; every other column is treated as categorical.
shop_static_real_col = ['rental_area','use_area','store_area', 'busi_area', 'gd_lat', 'gd_lgt', 'dis_income','avg_sale_amt','month_age']
shop_static_cate_col = shop_prep.columns.drop(shop_static_real_col).to_list()

# Categorical columns: missing -> 0, then string category.
for cate_col in shop_static_cate_col:
    shop_prep[cate_col] = shop_prep[cate_col].fillna(0).astype(str).astype('category')
# Remaining 64-bit numeric columns: missing -> 0, then downcast to 32 bits.
for wide_dtype, narrow_dtype in (('int64', 'int32'), ('float64', 'float32')):
    for real_col in shop_prep.select_dtypes(include=[wide_dtype]).columns:
        shop_prep[real_col] = shop_prep[real_col].fillna(0).astype(narrow_dtype)

##### 商品信息处理

In [None]:
# Text-typed columns
goods_str_list = goods_info.select_dtypes(include=['object']).columns.to_list()

# Every goods attribute is treated as categorical.
goods_static_cate_col = goods_info.columns.to_list()

# Categorical columns: missing -> 0, then string category.
for cate_col in goods_static_cate_col:
    goods_info[cate_col] = goods_info[cate_col].fillna(0).astype(str).astype('category')
# Downcast any remaining 64-bit numeric column of goods_info. The column names
# must come from goods_info itself (they were previously taken from shop_prep,
# whose columns do not exist in goods_info). Since every column is categorical
# at this point, these loops currently find nothing to convert.
for wide_dtype, narrow_dtype in (('int64', 'int32'), ('float64', 'float32')):
    for real_col in goods_info.select_dtypes(include=[wide_dtype]).columns:
        goods_info[real_col] = goods_info[real_col].fillna(0).astype(narrow_dtype)

In [None]:
# Join sales with the shop and goods attributes.
main = sale_df3.merge(shop_prep,on='shop_code').merge(goods_info,on='goods_code')
for int_col in ('dis_income', 'month_age', 'time_idx'):
    main[int_col] = main[int_col].astype('int32')
for col in main.select_dtypes(include=['float64']).columns:
    main[col] = main[col].astype('float32')

# Pairs with positive sales on more than 60% of the sale dates: training/validation data.
# astype(str) on the selection returns a new frame, so no assignment into a slice is needed.
pair_keys = ['shop_code', 'goods_code']
sale_left_06 = day_cnt.loc[day_cnt.天数 > sale_df.ds.nunique()*0.6, pair_keys].astype(str)
main_ts = main.merge(sale_left_06,how='inner',on=pair_keys)

# Every other row of `main`: kept for the later demand-prediction step.
# The split is made on the shop/goods keys. main_ts comes out of a merge and has
# a fresh 0..n-1 index, so comparing row indices with `main` would only drop the
# first len(main_ts) rows instead of the rows of the selected pairs.
ts_pairs = pd.MultiIndex.from_frame(sale_left_06)
is_ts_row = pd.MultiIndex.from_frame(main[pair_keys].astype(str)).isin(ts_pairs)
main_dl = main[~is_ts_row]

# Save both data sets.
main_ts.to_csv(base_path + 'main_ts.csv',index=False)
main_dl.to_csv(base_path + 'main_dl.csv',index=False)

# Save the column dtypes and the column lists.
# NOTE: despite the .json extension the file holds four consecutive pickle
# payloads; they must be read back with pickle.load in the same order.
main_dtypes_dict = main.dtypes.drop(['ds','shop_code','goods_code']).to_dict()
main_dtypes_dict_path = base_path + 'main_dtypes_dict.json'
with open(main_dtypes_dict_path, 'wb') as f:
    for payload in (main_dtypes_dict, shop_static_cate_col, goods_static_cate_col, shop_static_real_col):
        pickle.dump(payload, f)
        

### 四、TemporalFusionTransformer销售预测

In [None]:
# Read back the saved dtypes and column lists (four consecutive pickle payloads).
base_path = 'c:/OuNingyi/21级工程管理欧宁益/论文基础数据/'
main_dtypes_dict_path = base_path + 'main_dtypes_dict.json'
with open(main_dtypes_dict_path, 'rb') as f:
    main_dtypes_dict, shop_static_cate_col, goods_static_cate_col, shop_static_real_col = (
        pickle.load(f) for _ in range(4)
    )
# Load the main data set used to train the sales forecasting models.
main = pd.read_csv(base_path + 'main_ts.csv', dtype=main_dtypes_dict,parse_dates=['ds'])
for key_col in ('shop_code', 'goods_code'):
    main[key_col] = main[key_col].astype(str).astype('category')

In [None]:
# main.sales_scan_name.unique()
# main = main[main.sales_scan_name == '小店']

In [None]:
# ARIMA handles a single series at a time, so fit and score one model per shop/goods series.
holdout_after = '2023-07-24'  # observations strictly after this date form the evaluation window
loss_list = []
# observed=True: the keys are categoricals; only combinations present in the data are wanted.
arima_groups = main[['shop_code','goods_code','ds','sale_qty']].groupby(['shop_code','goods_code'], observed=True)
for _, group_df in arima_groups:
    group_df = group_df.sort_values('ds')
    is_holdout = group_df.ds > holdout_after
    history = group_df.loc[~is_holdout, 'sale_qty']
    actual = group_df.loc[is_holdout, 'sale_qty']
    if history.empty or actual.empty:
        continue  # nothing to fit on, or nothing to score against
    # Fit on the history only. Fitting on the full series and then forecasting
    # would predict the days AFTER the evaluation window while the window itself
    # had been seen during fitting.
    # NOTE(review): seasonal=True is passed without a seasonal period `m` -- confirm intended.
    arima = auto_arima(history, seasonal=True)
    # predict may return a Series or an ndarray; np.asarray handles both.
    sale_qty_predicted = np.asarray(arima.predict(n_periods=len(actual)), dtype='float32')
    # Loss of this series
    loss_list.append(SMAPE()(torch.tensor(sale_qty_predicted).reshape(-1, 1),
                             torch.tensor(actual.to_numpy(dtype='float32')).reshape(-1, 1)))
# Mean loss over all series
torch.stack(loss_list).mean()
    

In [None]:
# Use one year of history to forecast 7 days.
max_encoder_length = 365
max_prediction_length = 7

context_length, prediction_length = max_encoder_length, max_prediction_length

# Last time index that belongs to the training data.
training_cutoff = main["time_idx"].max() - max_prediction_length
# Optional subset of the data, for plotting
# main = main[main.goods_code.isin(['1020105787'])&(main.shop_code == '1215')]
training_data = main[main["time_idx"] <= training_cutoff].copy()

In [None]:
# Build the dataset in the form expected by the model.
# Shop and goods categorical attributes are the static categoricals;
# the continuous shop attributes are the static reals.
static_categorical_cols = shop_static_cate_col + goods_static_cate_col
# One label encoder per static categorical, fitted on the training rows.
static_categorical_encoders = {}
for label in static_categorical_cols:
    static_categorical_encoders[label] = NaNLabelEncoder(add_nan=True).fit(training_data[label])

known_categorical_cols = ['is_weekend', 'is_holiday', 'holiday_name', 'lunar_is_leap', 'solar_terms']
known_real_cols = [
    'time_idx', 'year', 'month', 'day', 'weekday', 'week_of_year', 'day_of_year', 'quarter',
    'lunar_year', 'lunar_month', 'lunar_day', 'sale_qty_yoy', 'sale_amt_yoy',
    'avg_sale_qty_by_shop_sku', 'avg_sale_qty_by_sku', 'avg_sale_qty_by_shop',
    'avg_sale_amt_by_shop_sku', 'avg_sale_amt_by_sku', 'avg_sale_amt_by_shop',
]
unknown_real_cols = ['sale_qty', 'sale_amt', 'log_sale_qty', 'log_sale_amt']

training = TimeSeriesDataSet(
    training_data,
    time_idx="time_idx",
    target="sale_qty",
    group_ids=['shop_code', 'goods_code'],
    static_categoricals=static_categorical_cols,
    static_reals=shop_static_real_col,
    time_varying_known_categoricals=known_categorical_cols,
    time_varying_known_reals=known_real_cols,
    time_varying_unknown_reals=unknown_real_cols,
    categorical_encoders=static_categorical_encoders,
    min_encoder_length=context_length,
    max_encoder_length=context_length,
    min_prediction_length=prediction_length,
    max_prediction_length=prediction_length,
    target_normalizer=GroupNormalizer(groups=['shop_code', 'goods_code'], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)
# Validation set: same configuration, applied to the full data in prediction mode.
validation = TimeSeriesDataSet.from_dataset(training, main, predict=True, stop_randomization=True)

In [None]:
# Matmul precision
torch.set_float32_matmul_precision('medium')
# Data loaders: identical options for both, only the `train` flag differs.
batch_size = 128
loader_options = dict(batch_size=batch_size, num_workers=8, persistent_workers=True, pin_memory=True)
train_dataloader = training.to_dataloader(train=True, **loader_options)
val_dataloader = validation.to_dataloader(train=False, **loader_options)

In [None]:
# Baseline error on the validation set.
baseline_predictions = Baseline().predict(val_dataloader)
# Put the actuals on whatever device the predictions ended up on, instead of
# assuming 'cuda:0' (which fails without a GPU or when another device is used).
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(baseline_predictions.device)
SMAPE()(baseline_predictions, actuals)

In [None]:
# Dataset for N-BEATS and N-HiTS: the target is the only model input
# (no static or time-varying covariates are passed here).
training_N = TimeSeriesDataSet(
    training_data,
    time_idx = "time_idx",
    target = "sale_qty",
    time_varying_unknown_reals = ['sale_qty'],
    categorical_encoders = {label:NaNLabelEncoder(add_nan = True).fit(training_data[f'{label}']) for label in shop_static_cate_col + goods_static_cate_col}, # one encoder per categorical column
    group_ids = ['shop_code','goods_code'],
    max_encoder_length = context_length,
    min_encoder_length = context_length,
    max_prediction_length = prediction_length,
    min_prediction_length = prediction_length,
    target_normalizer = GroupNormalizer(groups=['shop_code','goods_code'], transformation="softplus"),
    add_relative_time_idx = False,
)
# The validation set must be derived from training_N. Deriving it from the TFT
# dataset `training` would give these models batches carrying the TFT covariates,
# which they were not configured for.
validation_N = TimeSeriesDataSet.from_dataset(training_N, main, predict=True, stop_randomization=True,min_prediction_idx=training_cutoff + 1)
# Matmul precision
torch.set_float32_matmul_precision('medium')
# Data loaders
batch_size = 128
train_dataloader_N = training_N.to_dataloader(train=True, batch_size=batch_size)
val_dataloader_N = validation_N.to_dataloader(train=False, batch_size=batch_size)


In [None]:
# N-BEATS: learning-rate search
pl.seed_everything(24)
trainer_NBeats = pl.Trainer(accelerator="gpu", gradient_clip_val=1e-1)
net_NBeats = NBeats.from_dataset(
    training_N,
    learning_rate=3e-2,
    weight_decay=1e-2,
    widths=[32, 512],
    backcast_loss_ratio=0.1,
)
# Run the learning-rate finder and adopt its suggestion.
res_NBeats = Tuner(trainer_NBeats).lr_find(
    net_NBeats,
    train_dataloaders=train_dataloader_N,
    val_dataloaders=val_dataloader_N,
    min_lr=1e-5,
)
suggested_lr_NBeats = res_NBeats.suggestion()
print(f"suggested learning rate: {suggested_lr_NBeats}")
fig = res_NBeats.plot(show=True, suggest=True)
fig.show()
net_NBeats.hparams.learning_rate = suggested_lr_NBeats

In [None]:
# Stop when val_loss has not improved by at least 1e-4 for 10 validation checks.
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", min_delta=1e-4, patience=10, verbose=False)

nbeats_trainer_options = dict(
    max_epochs=100,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=0.01,
    callbacks=[early_stop_callback],
    limit_train_batches=150,
)
trainer_NBeats = pl.Trainer(**nbeats_trainer_options)

# Fresh network with a fixed learning rate.
nbeats_hparams = dict(
    learning_rate=0.022387211385683406,
    log_interval=10,
    log_val_interval=1,
    weight_decay=1e-2,
    widths=[32, 512],
    backcast_loss_ratio=1.0,
)
net_NBeats = NBeats.from_dataset(training_N, **nbeats_hparams)

trainer_NBeats.fit(net_NBeats, train_dataloaders=train_dataloader_N, val_dataloaders=val_dataloader_N)

In [None]:
# N-BEATS error on the validation set.
# Evaluate on val_dataloader_N, the loader the network was validated with during
# training; `val_dataloader` is the loader built for the TFT dataset.
predictions = net_NBeats.predict(val_dataloader_N, return_y=True,trainer_kwargs=dict(accelerator="gpu"))
# Actuals follow the device of the predictions instead of a hard-coded 'cuda:0'.
actuals = torch.cat([y[0] for x, y in iter(val_dataloader_N)]).to(predictions.output.device)
SMAPE()(predictions.output, actuals)

In [None]:
# N-HiTS: learning-rate search
pl.seed_everything(24)
trainer_NHiTS = pl.Trainer(accelerator="gpu", gradient_clip_val=1e-1)
net_NHiTS = NHiTS.from_dataset(
    training_N,
    learning_rate=3e-2,
    weight_decay=1e-2,
    backcast_loss_ratio=0.1,
)
# Run the learning-rate finder and adopt its suggestion.
res_NHiTS = Tuner(trainer_NHiTS).lr_find(
    net_NHiTS,
    train_dataloaders=train_dataloader_N,
    val_dataloaders=val_dataloader_N,
    min_lr=1e-5,
)
suggested_lr_NHiTS = res_NHiTS.suggestion()
print(f"suggested learning rate: {suggested_lr_NHiTS}")
fig = res_NHiTS.plot(show=True, suggest=True)
fig.show()
net_NHiTS.hparams.learning_rate = suggested_lr_NHiTS

In [None]:
# A dedicated EarlyStopping instance for this run. The callback keeps its best
# score and patience counter as instance state, so reusing the instance that was
# attached to the N-BEATS trainer would compare N-HiTS losses (a different loss
# function) against the N-BEATS best score and could stop the run prematurely.
early_stop_callback_NHiTS = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
trainer_NHiTS = pl.Trainer(
    max_epochs=100,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=0.01,
    callbacks=[early_stop_callback_NHiTS],
    limit_train_batches=150,
)


# Fresh network with a fixed learning rate and a distribution loss.
net_NHiTS = NHiTS.from_dataset(
    training_N,
    learning_rate=0.0031622776601683794,
    log_interval=10,
    log_val_interval=1,
    weight_decay=1e-2,
    backcast_loss_ratio=0.0,
    hidden_size=64,
    optimizer="AdamW",
    loss=MQF2DistributionLoss(prediction_length=max_prediction_length),
)

trainer_NHiTS.fit(
    net_NHiTS,
    train_dataloaders=train_dataloader_N,
    val_dataloaders=val_dataloader_N,
)

In [None]:
# N-HiTS error on the validation set.
# Evaluate on val_dataloader_N, the loader the network was validated with during
# training; `val_dataloader` is the loader built for the TFT dataset.
predictions = net_NHiTS.predict(val_dataloader_N, return_y=True,trainer_kwargs=dict(accelerator="gpu"))
# Actuals follow the device of the predictions instead of a hard-coded 'cuda:0'.
actuals = torch.cat([y[0] for x, y in iter(val_dataloader_N)]).to(predictions.output.device)
SMAPE()(predictions.output, actuals)

In [None]:
# Automatic hyper-parameter search for the TFT
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Search settings. Ranges that could additionally be overridden:
#   gradient_clip_val_range=(0.01, 1.0), hidden_size_range=(8, 128),
#   hidden_continuous_size_range=(8, 128), attention_head_size_range=(1, 4),
#   learning_rate_range=(0.001, 0.1), dropout_range=(0.1, 0.3),
#   trainer_kwargs=dict(limit_train_batches=1)
search_options = dict(
    model_path="optuna_test",
    n_trials=200,
    max_epochs=200,
    reduce_on_plateau_patience=4,
)
study = optimize_hyperparameters(train_dataloader, val_dataloader, **search_options)

# Persist the study, then show the best parameters found.
with open("test_study.pkl", "wb") as study_file:
    pickle.dump(study, study_file)

print(study.best_trial.params)

In [None]:
# Settings for the final training run
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", min_delta=1e-4, patience=10, verbose=False)
lr_logger = LearningRateMonitor()
tft_trainer_options = dict(
    max_epochs=200,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=6.542287226245969,
    limit_train_batches=50,
    callbacks=[lr_logger, early_stop_callback],
    logger=TensorBoardLogger("lightning_logs"),
    log_every_n_steps=10,
)
trainer = pl.Trainer(**tft_trainer_options)

tft_hparams = dict(
    learning_rate=0.02818382931264452,
    hidden_size=20,
    attention_head_size=1,
    dropout=0.2611769602088634,
    hidden_continuous_size=14,
    loss=SMAPE(),
    optimizer="Ranger",
    log_interval=10,
    reduce_on_plateau_patience=4,
)
tft = TemporalFusionTransformer.from_dataset(training, **tft_hparams)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

In [None]:
# Final training run of the TFT
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# Load the best model from a saved checkpoint.
# The checkpoint path is pinned to one specific earlier run; the commented line
# below would instead take the best checkpoint of the current `trainer`.
# best_model_path = trainer.checkpoint_callback.best_model_path
best_model_path = 'lightning_logs\\lightning_logs\\version_40\\checkpoints\\epoch=36-step=1850.ckpt'
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

In [None]:
# TFT error on the validation set.
predictions = best_tft.predict(val_dataloader, return_y=True,trainer_kwargs=dict(accelerator="gpu"))
# Actuals follow the device of the predictions instead of a hard-coded 'cuda:0'.
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(predictions.output.device)
SMAPE()(predictions.output, actuals)