In [39]:
import pandas as pd
import torch
import numpy as np
# Switch for the cell that builds `df_full`: when True, that cell expands `data`
# to every (time_id, stock_id) pair and writes the result to data/full_data.csv.
FULL_DATA = False

In [38]:
# Load the working table. This is the same path the FULL_DATA cell writes to.
data = pd.read_csv('data/full_data.csv')


In [40]:
data['stock_id'].unique().shape

(200,)

In [41]:
from itertools import product

if FULL_DATA:
    # Build every (time_id, stock_id) pair so that each time step ends up with
    # the same number of rows (one per stock).
    all_time_ids = data['time_id'].unique()
    all_stock_ids = range(200)  # stock ids 0..199
    all_combinations = pd.MultiIndex.from_product(
        [all_time_ids, all_stock_ids], names=['time_id', 'stock_id']
    ).to_frame(index=False)

    # Left-join the observed rows onto the full grid; unobserved pairs get NaN.
    df_full = all_combinations.merge(data, on=['time_id', 'stock_id'], how='left')

    # Restore the columns that are determined by time_id *before* zero-filling.
    # A blanket fillna(0) would otherwise stamp date_id == 0 on every padded
    # row, so a later `date_id <= 0` filter would pick up rows from all days.
    # NOTE(review): assumes each time_id maps to a single date_id /
    # seconds_in_bucket value -- confirm against the raw data.
    for key in ('date_id', 'seconds_in_bucket'):
        if key in data.columns:
            per_time_id = data.groupby('time_id')[key].first()
            df_full[key] = df_full[key].fillna(df_full['time_id'].map(per_time_id))

    # Zero-fill the remaining gaps and order rows by time step, then stock.
    df_full = df_full.fillna(0).sort_values(by=['time_id', 'stock_id'])

    # Persist the completed grid.
    df_full.to_csv('data/full_data.csv', index=False)
else:
    # Nothing to rebuild: `data` was read from data/full_data.csv. Bind df_full
    # anyway so the following cells do not raise NameError on a fresh kernel.
    # NOTE(review): presumably that file already holds the completed grid from
    # an earlier FULL_DATA run -- confirm.
    df_full = data


In [42]:
# Rebind `data` to df_full; every later cell reads `data`.
data = df_full

In [50]:
data.shape[0] 

5291000

In [56]:
# Keep only the rows whose date_id is <= 0.
# NOTE(review): presumably meant as a small first-day sample -- confirm.
df = data[data['date_id']<=0]

In [57]:
# Feature columns: every column of `data` except row_id, time_id, date_id and target.
features = [c for c in data.columns if c not in ['row_id', 'time_id', 'date_id', 'target']]

df 需要整理成形状为 (windows, stocks, features) 的张量，即（窗口内的时间步 × 股票 × 特征）。

In [61]:
df

Unnamed: 0,time_id,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,...,near_price_first,target_shift1,target_shift2,target_shift3,target_shift4,target_shift5,target_shift6,target_first,after55,median_size
0,0,0,0.0,0.0,3180602.69,1.0,0.999812,13380276.64,1.0,1.0,...,1.0,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-3.029704,0.0,42739.16
1,0,1,0.0,0.0,166603.91,-1.0,0.999896,1642214.25,1.0,1.0,...,1.0,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-5.519986,0.0,25548.50
2,0,2,0.0,0.0,302879.87,-1.0,0.999561,1819368.03,1.0,1.0,...,1.0,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-8.389950,0.0,26228.10
3,0,3,0.0,0.0,11917682.27,-1.0,1.000171,18389745.62,1.0,1.0,...,1.0,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-4.010200,0.0,41667.00
4,0,4,0.0,0.0,447549.96,-1.0,0.999532,17860614.95,1.0,1.0,...,1.0,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-0.047561,-7.349849,0.0,34014.58
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4278358,21391,158,0.0,0.0,0.00,0.0,0.000000,0.00,0.0,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.00
4278558,21392,158,0.0,0.0,0.00,0.0,0.000000,0.00,0.0,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.00
4278758,21393,158,0.0,0.0,0.00,0.0,0.000000,0.00,0.0,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.00
4278958,21394,158,0.0,0.0,0.00,0.0,0.000000,0.00,0.0,0.0,...,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.00


In [60]:
df.shape

(63613, 85)

In [59]:
df['time_id'].unique()

array([    0,     1,     2, ..., 21393, 21394, 24090], dtype=int64)

In [25]:
data[features].values.shape

(5291000, 81)

In [23]:
import numpy as np
import torch

def create_lstm_input(data, window_size, index, features):
    """Build the feature tensor for one sliding window of consecutive time_ids.

    Args:
        data: DataFrame with a ``time_id`` column and the ``features`` columns.
        window_size: Number of consecutive time_ids in the window.
        index: First time_id of the window; the window covers
            ``index .. index + window_size - 1`` inclusive.
        features: Column names to extract.

    Returns:
        torch.Tensor of shape (window_size, rows_per_time_id, len(features)),
        with rows kept in the order they appear in ``data``.

    Raises:
        ValueError: If the time steps in the window do not all have the same
            number of rows, so they cannot be stacked into one tensor.
    """
    last_time_id = index + window_size - 1
    window = data[(data['time_id'] >= index) & (data['time_id'] <= last_time_id)]

    # Pull the arrays out once; masking a NumPy array per step is much cheaper
    # than running a DataFrame boolean filter for every time step.
    time_ids = window['time_id'].to_numpy()
    values = window[features].to_numpy()
    steps = [values[time_ids == t] for t in range(index, last_time_id + 1)]

    # Fail with an actionable message instead of an opaque NumPy/torch error.
    rows_per_step = {len(step) for step in steps}
    if len(rows_per_step) > 1:
        raise ValueError(
            f"time_ids {index}..{last_time_id} have differing row counts "
            f"{sorted(rows_per_step)}; fill missing (time_id, stock_id) rows first"
        )
    return torch.tensor(np.array(steps))

def create_target(data, window_size, index):
    """Return the ``target`` values at the last time step of a window.

    Only the final time_id of the window is read, so the earlier steps are not
    materialised just to be thrown away.

    Args:
        data: DataFrame with ``time_id`` and ``target`` columns.
        window_size: Number of consecutive time_ids in the window (>= 1).
        index: First time_id of the window.

    Returns:
        torch.Tensor of shape (rows_at_last_time_id, 1), rows in the order they
        appear in ``data``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    last_time_id = index + window_size - 1
    last_step = data.loc[data['time_id'] == last_time_id, ['target']]
    return torch.tensor(last_step.to_numpy())


In [None]:

# Materialise every sliding window up front.
# Each window copies window_size time steps of `data`, so the stacked result
# holds roughly window_size copies of the table -- use only on small inputs.
batch_data = []
target_data = []
window_size = 55
num_windows = data['time_id'].max() - window_size + 2
for index in range(num_windows):
    # Report progress occasionally instead of printing one line per window.
    if index % 100 == 0:
        print(f"window {index}/{num_windows}")
    batch_data.append(create_lstm_input(data, window_size, index, features))
    target_data.append(create_target(data, window_size, index))

# torch.stack joins the per-window tensors directly, without a NumPy round trip.
batch_data = torch.stack(batch_data) if batch_data else torch.empty(0)
target_data = torch.stack(target_data) if target_data else torch.empty(0)




In [28]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TimeSeriesDataset(Dataset):
    """Map-style dataset that yields one sliding window of time steps per item.

    Item ``i`` covers time_ids ``i .. i + window_size - 1``: ``x`` is built by
    ``create_lstm_input`` and ``y`` by ``create_target`` for that window. Both
    are returned as float32 tensors.

    Args:
        data: DataFrame with a ``time_id`` column, the feature columns and the
            column(s) read by ``create_target``.
        window_size: Number of consecutive time_ids per window.
        features: Feature column names passed to ``create_lstm_input``.
    """

    def __init__(self, data, window_size, features):
        self.data = data
        self.window_size = window_size
        self.features = features
        # Compute the window count once: scanning the time_id column on every
        # __len__ call is wasted work on a large table.
        if len(data) == 0:
            self._num_windows = 0
        else:
            # Clamp at zero so len() never raises on a window longer than the data.
            self._num_windows = max(0, int(data['time_id'].max()) - window_size + 2)

    def __len__(self):
        """Number of windows, assuming time_ids start at 0."""
        return self._num_windows

    def __getitem__(self, index):
        """Return ``(x, y)`` for the window starting at time_id ``index``."""
        x = create_lstm_input(self.data, self.window_size, index, self.features)
        y = create_target(self.data, self.window_size, index)

        # The helpers return tensors in the DataFrame's dtype, so cast
        # explicitly to float32. as_tensor also accepts arrays should the
        # helpers ever stop returning tensors.
        x = torch.as_tensor(x, dtype=torch.float)
        y = torch.as_tensor(y, dtype=torch.float)

        return x, y


# Wrap the table in the windowed dataset.
dataset = TimeSeriesDataset(data=data, window_size=window_size, features=features)

# Serve the windows in shuffled mini-batches of 512.
data_loader = DataLoader(dataset=dataset, batch_size=512, shuffle=True)

# Smoke test: pull a single batch and show the tensor shapes.
for batch_data, target_data in data_loader:
    print(batch_data.shape, target_data.shape, sep="\n")
    break

KeyboardInterrupt: 

In [108]:
batch_data.shape

torch.Size([111, 55, 187, 13])

In [118]:
data['time_id'].max()-window_size+2

276

In [156]:
data[data['time_id']==219]['stock_id'].unique()

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  70,  71,  72,  73,  74,  75,  76,  77,  80,  81,
        82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
        95,  96,  97,  98,  99, 100, 103, 104, 105, 106, 107, 108, 109,
       110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
       123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 136, 137,
       138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 151,
       152, 154, 155, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167,
       168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180,
       181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 19

In [155]:
data[data['time_id']==220]['stock_id'].unique()

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  70,  71,  72,  73,  74,  75,  76,  77,  78,  80,
        81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,
        94,  95,  96,  97,  98,  99, 100, 103, 104, 105, 106, 107, 108,
       109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
       122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 134, 136,
       137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149,
       151, 152, 154, 155, 157, 159, 160, 161, 162, 163, 164, 165, 166,
       167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179,
       180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 19

In [138]:
data['time_id'].max()

329

In [157]:
data[data['time_id']==219]

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,bid_size,ask_price,ask_size,wap,target,time_id,row_id
40953,0,3,540,4785325.00,-1,0.998280,22303737.18,0.990578,0.995243,0.998280,52917.25,0.998388,160333.68,0.998307,-8.149743,219,3_540_0
40954,1,3,540,0.00,0,1.002769,3176068.93,1.002769,1.002769,1.002314,29254.18,1.002883,93807.56,1.002449,13.220310,219,3_540_1
40955,2,3,540,0.00,0,1.001531,2463111.96,1.001531,1.001531,1.000858,62809.50,1.001583,17599.40,1.001424,3.789663,219,3_540_2
40956,3,3,540,3883312.87,1,1.000434,74905949.11,1.001740,1.001740,1.000391,90767.05,1.000434,919.20,1.000434,-1.580119,219,3_540_3
40957,4,3,540,0.00,0,0.999825,30606040.06,0.999825,0.999825,0.999825,8643.70,0.999966,94243.80,0.999837,-1.860261,219,3_540_4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
41135,194,3,540,499978.56,-1,1.002884,10446613.35,1.002051,1.002051,1.002884,105311.34,1.003162,87494.75,1.003036,-0.500083,219,3_540_194
41136,195,3,540,1470152.67,1,1.000429,27617343.11,1.000875,1.000763,1.000317,196333.50,1.000429,30484.40,1.000414,-3.420114,219,3_540_195
41137,196,3,540,3058056.58,-1,1.001250,10105108.52,1.000246,1.000413,1.001250,34026.20,1.001752,61146.26,1.001430,-15.950203,219,3_540_196
41138,197,3,540,0.00,0,1.000623,18074458.04,1.000623,1.000623,1.000418,100039.01,1.000623,39408.50,1.000565,-1.189709,219,3_540_197
