In [50]:
import pandas as pd
import numpy as np
from tqdm import tqdm
data = pd.read_csv("./project1_data.csv")
# 将日期字符串转换为日期类型，并提取年份
data['year'] = pd.to_datetime(data['date'], format='%Y%m%d').dt.year

# 按股票代码分组
grouped_data = data.groupby('code')

sequence_length = 22  # 序列长度
features = ['open', 'high', 'low', 'close', 'volume', 'amount']  # 特征列

# 初始化列表
X_train_list = []
y_train_list = []
X_test_list = []
y_test_list = []
progress_bar = tqdm(total=len(grouped_data))
for code, group_df in grouped_data:

    if len(group_df) < sequence_length:
        # 如果分组长度不够，跳过该分组
        progress_bar.update(1)
        print("jump", code)
        continue

    # 按年份分割数据
    train_df = group_df[group_df['year'].isin([2010, 2011, 2012])]
    test_df = group_df[group_df['year'] == 2013]

    def generate_sequences(df, X_list, y_list):
        if len(df) >= sequence_length:
            X = []
            for i in range(len(df) - sequence_length):
                sequence_data = df[features].values[i:i+sequence_length]
                # Check the dimensionality of sequence_data
                if np.ndim(sequence_data) == 2:  # Ensure it's a 3D array
                    X.append(sequence_data)
            if X:  # Ensure X is not empty before appending
                X = np.array(X)
                y = df['label'].values[sequence_length:]
                X_list.append(X)
                y_list.append(y)


    # 对训练集和测试集数据分别生成序列
    generate_sequences(train_df, X_train_list, y_train_list)
    generate_sequences(test_df, X_test_list, y_test_list)
    progress_bar.update(1)

progress_bar.close()

# 组合训练集和测试集的数据
X_train = np.vstack(X_train_list)
y_train = np.concatenate(y_train_list)
X_test = np.vstack(X_test_list)
y_test = np.concatenate(y_test_list)

 81%|████████▏ | 2017/2480 [07:55<02:05,  3.68it/s]

jump 600591.SH


 82%|████████▏ | 2033/2480 [07:59<01:59,  3.74it/s]

jump 600607.SH


 91%|█████████ | 2250/2480 [08:57<01:04,  3.57it/s]

jump 600842.SH


100%|██████████| 2480/2480 [09:50<00:00,  4.20it/s]


In [51]:
import pickle

# 将变量保存到文件
with open("./data22_plus_amount.pkl", 'wb') as f:
    pickle.dump((X_train, X_test, y_train, y_test), f)