In [7]:
import os

import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler
from sqlalchemy import create_engine, text
from torch.utils.data import DataLoader, Dataset

def create_my_engine(url=None):
    """Create a SQLAlchemy engine for the ``financedb`` database.

    Args:
        url: Optional SQLAlchemy connection URL. When omitted, the
            ``FINANCEDB_URL`` environment variable is used; if that is unset,
            the original local development URL is used so existing calls keep
            working.

    Returns:
        The engine returned by ``sqlalchemy.create_engine``.
    """
    if url is None:
        # FIXME: the fallback embeds local dev credentials in the notebook;
        # prefer exporting FINANCEDB_URL so secrets are not committed.
        url = os.environ.get('FINANCEDB_URL', 'mysql+pymysql://root:mysql@localhost:3306/financedb')
    return create_engine(url)


In [8]:

class MyDataset(Dataset):
    """Sliding-window dataset pairing a fund's NAV history with SH/SZ index data.

    Each sample is a window of ``seq_len`` consecutive rows of the features in
    ``FEATURE_COLUMNS`` (min-max scaled), and its label is the scaled ``nav`` of
    the row immediately after the window (the next trading day).

    Attributes:
        seq_len: Window length in rows.
        df_data: Unscaled feature table for the whole date range (both splits),
            ordered by date.
        scaler: ``MinMaxScaler`` fitted on the training rows only; use it to map
            scaled values back with ``inverse_transform``.
        data: ``float32`` tensor holding the scaled rows of the requested split.
    """

    # Column 0 must stay 'nav' because ``__getitem__`` uses it as the label.
    FEATURE_COLUMNS = ['nav', 'c_nav', 'growth_rate', 'Close_sh', 'Volume_sh', 'Close_sz', 'Volume_sz']

    def __init__(self, fund_code: str, sh_index_path: str, sz_index_path: str, start_date: str, end_date: str,
                 seq_len: int, type_: str, train_ratio: float = 0.8):
        """Load, align, scale and split the data.

        Args:
            fund_code: Value matched against ``fund_data.fund_code``.
            sh_index_path: CSV with at least ``Date``, ``Close``, ``Volume`` columns.
            sz_index_path: CSV with at least ``Date``, ``Close``, ``Volume`` columns.
            start_date: Inclusive lower bound on ``fund_data.date``.
            end_date: Inclusive upper bound on ``fund_data.date``.
            seq_len: Number of rows per input window.
            type_: ``'train'`` selects the first ``train_ratio`` of rows; any
                other value selects the remaining rows.
            train_ratio: Fraction of rows (chronologically first) used for training.

        Raises:
            ValueError: If ``train_ratio`` is not in ``(0, 1]``, no dates overlap
                across the three sources, or the training split would be empty.
        """
        if not 0.0 < train_ratio <= 1.0:
            raise ValueError(f'train_ratio must be in (0, 1], got {train_ratio}')
        self.seq_len = seq_len

        # Load the raw sources.
        fund_data = self._load_fund_data(fund_code, start_date, end_date)
        sh_index_data = pd.read_csv(sh_index_path)
        sz_index_data = pd.read_csv(sz_index_path)

        # Align on trading date and select the model features.
        merged_data = self._merge_sources(fund_data, sh_index_data, sz_index_data)
        if merged_data.empty:
            raise ValueError(f'No overlapping dates for fund {fund_code} between {start_date} and {end_date}')
        features = merged_data[self.FEATURE_COLUMNS]
        self.df_data = features.copy()

        # Chronological split: the first `train_ratio` of rows are training data.
        split_idx = int(len(features) * train_ratio)
        if split_idx == 0:
            raise ValueError(f'Not enough rows ({len(features)}) to form a training split')

        # Fit the scaler on training rows only, so min/max values from the test
        # period do not leak into the training inputs (look-ahead bias).
        self.scaler = MinMaxScaler()
        self.scaler.fit(features.iloc[:split_idx])
        scaled = self.scaler.transform(features)

        split_rows = scaled[:split_idx] if type_ == 'train' else scaled[split_idx:]
        self.data = torch.tensor(split_rows, dtype=torch.float32)

    @staticmethod
    def _load_fund_data(fund_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Read one fund's rows with ``start_date <= date <= end_date`` from ``fund_data``."""
        # Bound parameters instead of f-string interpolation: prevents SQL
        # injection and does not rely on MySQL accepting double-quoted literals.
        query = text(
            'SELECT * FROM fund_data '
            'WHERE fund_code = :fund_code AND date >= :start_date AND date <= :end_date '
            'ORDER BY date'
        )
        engine = create_my_engine()
        try:
            return pd.read_sql(sql=query, con=engine,
                               params={'fund_code': fund_code, 'start_date': start_date, 'end_date': end_date})
        finally:
            # A new engine is created per dataset; release its connection pool.
            engine.dispose()

    @staticmethod
    def _prepare_index(index_data: pd.DataFrame, suffix: str) -> pd.DataFrame:
        """Keep an index CSV's Date/Close/Volume as 'date', 'Close<suffix>', 'Volume<suffix>'."""
        return pd.DataFrame({
            'date': pd.to_datetime(index_data['Date']),
            f'Close{suffix}': index_data['Close'],
            f'Volume{suffix}': index_data['Volume'],
        })

    @classmethod
    def _merge_sources(cls, fund_data: pd.DataFrame, sh_index_data: pd.DataFrame,
                       sz_index_data: pd.DataFrame) -> pd.DataFrame:
        """Inner-join fund rows with SH and SZ index rows on date, sorted by date with a 0..N-1 index."""
        fund = fund_data.assign(date=pd.to_datetime(fund_data['date']))
        merged = (
            fund
            .merge(cls._prepare_index(sh_index_data, '_sh'), on='date', how='inner')
            .merge(cls._prepare_index(sz_index_data, '_sz'), on='date', how='inner')
        )
        # Windows and the train/test split assume chronological row order.
        return merged.sort_values('date', kind='stable').reset_index(drop=True)

    def __getitem__(self, idx):
        """Return ``(window, label)`` for the window starting at row ``idx``.

        ``window`` has shape ``(seq_len, n_features)``; ``label`` is column 0
        (scaled ``nav``) of the row right after the window, i.e. the next day.
        """
        seq = self.data[idx:idx + self.seq_len, :]
        label = self.data[idx + self.seq_len, 0]
        return seq, label

    def __len__(self):
        """Number of complete (window, label) pairs; 0 if the split is shorter than a window."""
        return max(len(self.data) - self.seq_len, 0)


In [9]:


# Experiment configuration: which fund, which index files, which period and window.
FUND_CODE = '510050'
SH_INDEX_PATH = '../../data/raw/sh.csv'
SZ_INDEX_PATH = '../../data/raw/sz.csv'
START_DATE = '2020-01-01'
END_DATE = '2023-01-01'
SEQ_LEN = 96

dataset = MyDataset(
    fund_code=FUND_CODE,
    sh_index_path=SH_INDEX_PATH,
    sz_index_path=SZ_INDEX_PATH,
    start_date=START_DATE,
    end_date=END_DATE,
    seq_len=SEQ_LEN,
    type_='train',
)
# Display the unscaled, date-aligned feature table the dataset was built from.
dataset.df_data

Unnamed: 0,nav,c_nav,growth_rate,Close_sh,Volume_sh,Close_sz,Volume_sz
0,2.647,3.871,0.49,3089.260010,217500.0,11015.990234,745500.0
1,2.634,3.856,-0.64,3073.699951,215600.0,10996.410156,785400.0
2,2.651,3.876,0.23,3087.399902,224600.0,11010.530273,914100.0
3,2.645,3.869,1.15,3095.570068,222200.0,11106.500000,822700.0
4,2.615,3.833,-0.30,3065.562988,206500.0,10978.990234,771800.0
...,...,...,...,...,...,...,...
723,3.031,4.173,-1.17,3066.893066,297900.0,10706.870117,202900.0
724,3.067,4.216,0.56,3104.802002,276600.0,10829.049805,171800.0
725,3.050,4.196,-0.72,3083.407959,312600.0,10698.269531,187700.0
726,3.072,4.222,-0.39,3083.785889,261500.0,10656.410156,159000.0
