In [30]:
import os

import numpy as np
import torch
from src import train, predict
from src.utils import plot_fit_figures
import traceback
from src.models import LSTM
from sklearn.model_selection import ParameterGrid
import itertools
from src.utils import eval_func
from tqdm import tqdm

In [31]:
# 超参数的可能值
best_params = {
    'hidden_size': 56,
    'num_layers': 2,
    'batch_size': 64,
    'seq_len': 64,
}

hidden_size_options = [64]
num_layers_options = [2]
batch_size_options = [64]
seq_len_options = [64]
# 所有超参数组合
param_grid = list(itertools.product(hidden_size_options, num_layers_options, batch_size_options, seq_len_options))

In [32]:
def run(best_config: dict, fund_code: str):
    model = LSTM(input_size=7,
                 output_size=7,
                 hidden_size=best_config['hidden_size'],
                 num_layers=best_config['num_layers'],
                 )
    config = {
        'model': model,
        'fund_code': fund_code,
        'data_set_length': 1000,
        'batch_size': best_config['batch_size'],
        'num_epochs': 60,
        'seq_len': best_config['seq_len'],
    }
    train(**config)
    config.pop('num_epochs')
    config.pop('data_set_length')
    predictions, groundtruths = predict(**config)
    plot_fit_figures(fund_code=fund_code, predictions=predictions, groundtruths=groundtruths)
    mse, mae, rmse, smape = eval_func(predictions, groundtruths)
    print(f'mse: {mse}, mae: {mae}, rmse: {rmse}, smape: {smape}')
    return predictions, groundtruths


In [33]:
run(best_params, fund_code='000079')

-------数据集加载中-------
Dataset type: train, data length: 712
Dataset type: valid, data length: 203
--------加载完成---------
----开始在cuda上训练------


训练中: 000079: 100%|██████████| 60/60 [00:02<00:00, 27.40it/s]
  checkpoint = torch.load(f'checkpoints/{fund_code}.pt')


--------训练完成---------
-----loss曲线绘制完成-----
-------模型保存完成-------
Dataset type: test, data length: 223
---------开始预测---------


100%|██████████| 2/2 [00:00<00:00, 668.95it/s]


---------预测完成---------
---------评估完成---------
mse: 0.007078101392835379, mae: 0.06949492543935776, rmse: 0.08413144946098328, smape: 6.005788594484329


(array([[1.2022797, 1.2117378, 1.205519 , 1.1692517, 1.1016989, 1.2243798,
         1.141257 ],
        [1.2050085, 1.2143115, 1.2061437, 1.1686906, 1.1018659, 1.2233446,
         1.1402651],
        [1.2147235, 1.224754 , 1.209348 , 1.1724368, 1.1286553, 1.2253747,
         1.1585314],
        [1.2183926, 1.2282876, 1.2083845, 1.1584444, 1.1069214, 1.2113845,
         1.1385328],
        [1.2245911, 1.2344081, 1.210263 , 1.1590877, 1.1235596, 1.2121047,
         1.15018  ],
        [1.2271944, 1.2368506, 1.2103438, 1.1556032, 1.117052 , 1.2086236,
         1.1432732],
        [1.2238009, 1.232341 , 1.2098039, 1.1595483, 1.1070735, 1.2135777,
         1.1364169],
        [1.2187212, 1.2267685, 1.2085409, 1.1561732, 1.0921223, 1.2114788,
         1.1265485],
        [1.2151432, 1.223447 , 1.2076473, 1.1522233, 1.0890857, 1.2086954,
         1.1258984],
        [1.2107697, 1.2189124, 1.206722 , 1.1558558, 1.0877687, 1.2127783,
         1.1265327],
        [1.209665 , 1.2183716, 1.2069457

In [None]:
def grid_search(param_grid, fund_code):
    with open('grid_search_results.csv', 'w', encoding="utf-8") as f:
        f.write('hidden_size,num_layers,batch_size,seq_len,mse,mae,rmse,smape\n')
        for param in tqdm(param_grid):
            hidden_size, num_layers, batch_size, seq_len = param
            model = LSTM(input_size=7,
                         output_size=7,
                         hidden_size=hidden_size,
                         num_layers=num_layers,
                         )
            config = {
                'model': model,
                'fund_code': fund_code,
                'data_set_length': 1000,
                'batch_size': batch_size,
                'num_epochs': 60,
                'seq_len': seq_len,
            }
            try:
                train(**config)
                config.pop('num_epochs')
                config.pop('data_set_length')
                predictions, groundtruths = predict(**config)
                plot_fit_figures(fund_code=fund_code, predictions=predictions, groundtruths=groundtruths)
                mse, mae, rmse, smape = eval_func(predictions, groundtruths)
                f.write(f'{hidden_size},{num_layers},{batch_size},{seq_len},{mse},{mae},{rmse},{smape}\n')

            except Exception as e:
                print(param)
                print(traceback.format_exc())
                continue
