In [None]:
# On AutoDL, network acceleration must be enabled for the first model download.
# /etc/network_turbo is a shell script expected to export the proxy variables; it
# is sourced in a bash subprocess and only proxy-related variables are copied into
# this kernel's environment. On machines without the script this is a no-op.
import os
import subprocess
import warnings

NETWORK_TURBO_SCRIPT = '/etc/network_turbo'


def load_proxy_env(script_path=NETWORK_TURBO_SCRIPT):
    """Source ``script_path`` in bash and copy its proxy variables into ``os.environ``.

    Args:
        script_path: shell script to source (AutoDL's network-turbo script by default).

    Returns:
        dict mapping each copied variable name to its value. Empty when the script
        does not exist (e.g. not running on AutoDL) or could not be sourced.
    """
    if not os.path.isfile(script_path):
        return {}
    try:
        # Pass the path as "$1" rather than formatting it into the command, and
        # avoid shell=True so there is only a single shell layer.
        result = subprocess.run(
            ['bash', '-c', 'source "$1" && env', 'bash', script_path],
            capture_output=True,
            text=True,
        )
    except OSError as exc:  # e.g. bash not installed
        warnings.warn(f"could not run bash to source {script_path}: {exc}")
        return {}
    if result.returncode != 0:
        warnings.warn(f"sourcing {script_path} failed: {result.stderr.strip()}")
        return {}

    proxies = {}
    for line in result.stdout.splitlines():
        name, sep, value = line.partition('=')
        # Match on the variable *name*, case-insensitively (so HTTP_PROXY counts too),
        # instead of grepping whole lines, which also picked up unrelated variables
        # whose value merely contained "proxy".
        if sep and 'proxy' in name.lower():
            proxies[name] = value
    os.environ.update(proxies)
    return proxies


load_proxy_env()


In [None]:
# Core third-party libraries used throughout the notebook.
import pandas as pd
import torch

# Opt in to pandas Copy-on-Write so that frames derived from another frame
# behave as independent copies instead of possibly aliasing the original.
pd.set_option("mode.copy_on_write", True)

In [None]:
# Global configuration parameters: every tunable value of this run lives here.
config = dict(
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',

    # Ticker code of the ETF being modelled.
    etf_code='510500',

    # Data location and sampling.
    data_dir='data',
    sample_sequence_window=90,   # input sequence length; presumably in trading days — confirm in ETFDataset
    max_news_per_day=20,
    max_news_length=456,         # NOTE(review): presumably a per-item token limit — confirm against the tokenizer

    trainer_config=dict(
        batch_size=32,
        num_epochs=100,
        patience=10,             # presumably early-stopping patience in epochs — confirm in Trainer
        grad_clip=1.0,

        checkpoint_path='checkpoint/',
        checkpoint_interval=5,
        model_save_path='model/',
    ),

    model_config=dict(
        news_emb_aggregate_output_size=256,
        fusion_hidden=256,
        tech_feature_dim=12,     # TODO confirm: should equal the number of technical-feature columns
        pred_days=5,
        gru_hidden=256,
    ),
)
device = torch.device(config['device'])

In [None]:
# Tokenizer of the Chinese FinBERT tone model, fetched from the Hugging Face Hub
# (or the local cache). The same object is handed to the data preprocessor and
# to the dataset in the cells below.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="yiyanghkust/finbert-tone-chinese",
)

In [None]:
from data_preprocessor import DataPreprocessor
# Prepare the data: build the preprocessor from the global config and the shared
# tokenizer, then load its data frames (dp.etf_df is consumed by the next cell).
# NOTE(review): from_csv=True presumably loads previously saved CSV files rather
# than fetching fresh data — confirm in DataPreprocessor.load_data_frame.
dp = DataPreprocessor(config, tokenizer)
dp.load_data_frame(from_csv=True)

In [None]:
# Build the training dataset from the ETF frame and the preprocessed news,
# tokenized with the shared tokenizer. The last two arguments are the input
# window length and the prediction horizon (meaning inferred from the config
# key names — confirm against ETFDataset's signature).
from etf_dataset import ETFDataset

ds = ETFDataset(
    dp.etf_df,
    dp.preprocess_news(),
    tokenizer,
    config['sample_sequence_window'],
    config['model_config']['pred_days'],
)

In [None]:
# Sanity check: report the size of the ETF frame, the news mapping and the
# resulting dataset, and stop early if no samples could be built — otherwise the
# loaders (and the smoke test at the end) would silently iterate over nothing.
print(f"etf_seq_size:\t{len(ds.etf_df)}\nnews_seq_size:\t{len(ds.news_dict)}\ndataset_size:\t{len(ds)}")
assert len(ds) > 0, "ETFDataset is empty - is the data range shorter than the configured window / pred_days?"

In [None]:
from trainer import Trainer
# Initialize the trainer from the global config on the selected device, then let
# it build its data loaders from the dataset. The next cell relies on the
# resulting trainer.test_loader, trainer.model and trainer.device attributes.
trainer = Trainer(config, device)
trainer.init_dataloader(ds)

In [None]:
# Smoke test: push a single test sample through the model to check that the batch
# keys, tensor shapes and the forward() signature line up.
batch = next(iter(trainer.test_loader), None)
if batch is None:
    raise RuntimeError("trainer.test_loader is empty; cannot run the forward-pass smoke test")

# `[:1]` keeps the leading batch dimension, equivalent to `[0].unsqueeze(0)`.
tech_data, input_ids, attention_mask, news_weights = (
    batch[key][:1].to(trainer.device)
    for key in ('tech_data', 'input_ids', 'attention_mask', 'news_weights')
)

# Assumes trainer.model is a torch.nn.Module (it is called like one above).
was_training = trainer.model.training
trainer.model.eval()  # disable dropout/batch-norm updates for an inference-style check
try:
    with torch.no_grad():  # no autograd graph is needed just to inspect outputs
        outputs = trainer.model(tech_data, input_ids, attention_mask, news_weights)
finally:
    trainer.model.train(was_training)  # leave the model in the mode training expects
print(outputs)