In [1]:
#合并下载数据
import pandas as pd
import os
import glob

import pandas as pd
import glob
import os

def merge_parquet(folder, output_file, delete_sources=True):
    """Merge every ``*.parquet`` file found in ``folder`` into ``output_file``.

    The merge target itself is never read as an input. Columns that are
    entirely NaN after concatenation are dropped. Source files are removed
    only after the merged file has been written successfully, and only the
    files that were actually read into the merge are removed; files that
    could not be read are reported and left on disk.

    NOTE(review): an already existing ``output_file`` is overwritten, not
    extended. Re-running after new source files arrive replaces the previous
    merge result; confirm this is intended.

    Args:
        folder: Directory that is searched (non-recursively) for parquet files.
        output_file: Path of the merged parquet file to write.
        delete_sources: When True (default, the historical behaviour) the
            successfully merged source files are deleted after the save.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    output_abs = os.path.abspath(output_file)
    # Exclude the merge target up front so the emptiness check below reflects
    # real inputs; sort for a deterministic row order in the merged frame.
    files = sorted(
        path
        for path in glob.glob(os.path.join(folder, "*.parquet"))
        if os.path.abspath(path) != output_abs
    )

    if not files:
        print("未找到Parquet文件。")
        return

    dfs = []
    merged_files = []
    for path in files:
        try:
            dfs.append(pd.read_parquet(path))
        except Exception as e:
            print(f"跳过损坏文件: {path}, 错误: {e}")
        else:
            merged_files.append(path)

    if not dfs:
        print("没有有效的数据被读取。")
        return

    df_merged = pd.concat(dfs, axis=0, ignore_index=True)
    df_merged = df_merged.dropna(axis=1, how='all')

    # Step 1: save first. If this fails nothing is deleted.
    try:
        df_merged.to_parquet(output_file)
    except Exception as e:
        print(f"!!! 保存失败 !!! 源文件未被删除。错误信息: {e}")
        return
    print(f"合并成功，文件已保存至: {output_file}")

    # Step 2: remove only the inputs whose rows are now inside the merged file.
    if delete_sources:
        for path in merged_files:
            try:
                os.remove(path)
                print(f"已删除源文件: {path}")
            except OSError as e:
                print(f"无法删除文件 {path}: {e}")

    print("全部完成。")




In [4]:
# Usage example: merge the raw files of every (market, day, nSigFigs) folder.
markets = ['ETH', 'BTC', 'SOL']
year = 2025
mounth = 12  # (sic) name kept as-is
days = ['14']
nSigFigs = ['4', '5', '4.0', '5.0']

# Same iteration order as three nested loops: market -> day -> nSigFigs.
merge_targets = [(m, d, n) for m in markets for d in days for n in nSigFigs]
for market, day, nsigfig in merge_targets:
    input_path=f"/home/jack_li/python/LOB_research/fetch_data/data/{year}-{mounth}-{day}/{market}/nSigFigs={nsigfig}/"
    output_path=f"/home/jack_li/python/LOB_research/fetch_data/data/{year}-{mounth}-{day}/{market}/nSigFigs={nsigfig}/merged.parquet"
    print(input_path)
    print(output_path)
    merge_parquet(input_path, output_path)

/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=4/
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=4/merged.parquet
未找到Parquet文件。
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5/
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5/merged.parquet
未找到Parquet文件。
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=4.0/
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=4.0/merged.parquet
未找到Parquet文件。
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5.0/
/home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5.0/merged.parquet
合并成功，文件已保存至: /home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5.0/merged.parquet
已删除源文件: /home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/ETH/nSigFigs=5.0/l2book_1765667031827.parquet
已删除源文件: /home/jack_li/python/LOB_research/fetch_data/data

In [3]:
#处理数据原始代码
import pandas as pd
import numpy as np
import re

# ==========================================
# 1. Load the merged file (replace the path with your own parquet file)
# ==========================================
df = (
    pd.read_parquet(r'/home/jack_li/python/LOB_research/fetch_data/data/2025-11-25/BTC/merged.parquet')
    .sort_values(by="exchange_time")
)
# Difference between each row's exchange_time and the previous row's.
df['time_diff'] = df['exchange_time'].sub(df['exchange_time'].shift(1))
# -----------------------------------------------

# ==========================================
# 2. 数据预处理：解析价格并计算中间价
# ==========================================
def parse_first_price(array_str):
    """Extract the first (level-1) price from the textual form of a price array.

    The argument is converted with ``str()`` first, so both real strings and
    list/array objects are accepted. Brackets, quotes and commas are treated
    as separators, which covers the numpy repr (``"['1' '2']"``) as well as
    the list repr (``"['1', '2']"`` or ``"[1.0, 2.0]"``).

    Args:
        array_str: A string, or any object whose ``str()`` looks like a
            bracketed sequence of numbers.

    Returns:
        The first number as ``float``, or ``np.nan`` when no token exists or
        the first token is not a valid number.
    """
    # Commas must be removed too, otherwise "100," reaches float() and fails.
    tokens = re.sub(r"[\[\]'\",]", " ", str(array_str)).split()
    if not tokens:
        return np.nan
    try:
        return float(tokens[0])
    except ValueError:
        return np.nan

print("正在计算 Mid-Price...")
# Level-1 price of each side, parsed element by element.
df['bid_px_1'] = df['bids_px'].map(parse_first_price)
df['asks_px_1'] = df['asks_px'].map(parse_first_price)

# Mid-price p_t: arithmetic mean of the two level-1 prices.
df['mid_price'] = df['bid_px_1'].add(df['asks_px_1']).div(2)


import pandas as pd
import numpy as np


# df is expected to be loaded and sorted by exchange_time at this point.

# ==========================================
# Step 1: locate gaps and derive a session id
# ==========================================
# Rows further apart than this threshold start a new session. Per the
# original author's note, ticks are ~600 ms apart, so 1000 ms is a
# reasonable cut; 2-3 s would be more tolerant of occasional delays.
GAP_THRESHOLD = 1000  # milliseconds (per the original author's note)

# Difference to the previous row's exchange_time.
df['time_diff'] = df['exchange_time'].diff()

# The first row has a NaN difference; NaN compares False, so the first row
# does not open a new session by itself.
df['is_gap'] = df['time_diff'].gt(GAP_THRESHOLD).fillna(False)

# Running count of gaps seen so far: the id increases by one at every gap.
df['session_id'] = df['is_gap'].cumsum()

print(f"数据被切分为 {df['session_id'].nunique()} 个连续片段。")

# ==========================================
# 步骤 2: 在组内计算 Label (定义函数)
# ==========================================
def calculate_labels_within_group(group_df, k_values=(3, 5, 10, 30, 50), alpha=0.0002):
    """Add trailing and forward rolling mid-price means for one session.

    For every horizon ``k`` two columns are added:

    * ``m_minus_{k}``: mean of ``mid_price`` over the current row and the
      ``k - 1`` rows before it.
    * ``m_plus_{k}``: the same rolling mean shifted by ``-k``, i.e. the mean
      over the next ``k`` rows.

    Sessions with fewer than ``max(k_values) + 1`` rows are too short; for
    them ``label_{k}``, ``m_minus_{k}`` and ``m_plus_{k}`` are all filled
    with NaN so that every returned frame has the rolling-mean columns.

    The input frame is not modified; a copy is returned.

    Args:
        group_df: Frame of one session containing a ``mid_price`` column,
            ordered in time.
        k_values: Horizons (row counts) to compute.
        alpha: Unused; kept so existing callers that pass it keep working.

    Returns:
        A copy of ``group_df`` with the additional columns described above.
    """
    result = group_df.copy()  # never mutate the object handed in by groupby.apply
    horizons = list(k_values)
    if not horizons:
        return result

    if len(result) < max(horizons) + 1:
        for k in horizons:
            result[f'label_{k}'] = np.nan
            result[f'm_minus_{k}'] = np.nan
            result[f'm_plus_{k}'] = np.nan
        return result

    for k in horizons:
        trailing_mean = result['mid_price'].rolling(window=k).mean()
        result[f'm_minus_{k}'] = trailing_mean
        # Shifting inside the session keeps the forward window from reaching
        # into the next session.
        result[f'm_plus_{k}'] = trailing_mean.shift(-k)

    return result

# ==========================================
# Step 3: run the per-session computation
# ==========================================
# calculate_labels_within_group is called once per session_id group and the
# per-group results are concatenated (group_keys=False keeps the keys out of
# the index). The original author notes this can be slow with many groups.
df_unlabeled = df.groupby('session_id', group_keys=False).apply(calculate_labels_within_group)

# Leftover from an earlier version: rows without a label (e.g. the head/tail
# of each session) were meant to be dropped here. NOTE(review): `label_100`
# is not among the horizons used in this notebook.
# df_final = df_labeled.dropna(subset=['label_100'])

import pandas as pd
import numpy as np

def apply_dynamic_labels(df, k_values=(3, 5, 10, 30, 50), lower_q=0.3333, upper_q=0.6667):
    """Write three-class labels derived from quantile thresholds into ``df``.

    For every horizon ``k`` the relative change
    ``(m_plus_k - m_minus_k) / m_minus_k`` is computed. Rows above the
    ``upper_q`` quantile of the valid changes get ``1``, rows below the
    ``lower_q`` quantile get ``-1``, the rest ``0``. Rows whose change is NaN
    keep a NaN label. Horizons whose ``m_minus``/``m_plus`` columns are
    missing, or that have no valid change at all, are skipped with a message.

    ``df`` is modified in place (``label_{k}`` columns are added or
    overwritten) and the same object is returned.

    Args:
        df: Frame containing ``m_minus_{k}`` and ``m_plus_{k}`` columns.
        k_values: Horizons to label.
        lower_q: Quantile used as the "down" threshold.
        upper_q: Quantile used as the "up" threshold.

    Returns:
        ``df`` itself, with float ``label_{k}`` columns.
    """
    print("开始计算动态阈值并更新标签...")

    for k in k_values:
        m_minus_col = f'm_minus_{k}'
        m_plus_col = f'm_plus_{k}'
        label_col = f'label_{k}'

        if m_minus_col not in df.columns or m_plus_col not in df.columns:
            print(f"跳过 k={k}: 缺少对应的 m_minus 或 m_plus 列")
            continue

        # Smoothed return whose distribution defines the thresholds.
        raw_change = (df[m_plus_col] - df[m_minus_col]) / df[m_minus_col]

        valid_changes = raw_change.dropna()
        if valid_changes.empty:
            print(f"k={k}: 无有效数据，跳过")
            continue

        threshold_down = valid_changes.quantile(lower_q)
        threshold_up = valid_changes.quantile(upper_q)

        print(f"k={k}: 下跌阈值 < {threshold_down:.6f}, 上涨阈值 > {threshold_up:.6f}")

        # Start from a float Series: the NaN assignment below would otherwise
        # force an int -> float upcast on item assignment.
        new_labels = pd.Series(0.0, index=df.index)
        new_labels[raw_change > threshold_up] = 1.0
        new_labels[raw_change < threshold_down] = -1.0
        new_labels[raw_change.isna()] = np.nan

        df[label_col] = new_labels

    return df

# Label the frame produced by the per-session step. apply_dynamic_labels
# writes into the frame it receives and returns that same object.
df_labeled = apply_dynamic_labels(df_unlabeled)

# Print the relative frequency of each class per horizon to verify balance.
print("\n标签分布检查:")
for k in [3, 5, 10, 30, 50]:
    label_name = f'label_{k}'
    if label_name not in df_labeled.columns:
        continue
    print(f"\nLabel {k} 分布:")
    print(df_labeled[label_name].value_counts(normalize=True).sort_index())
def top10(x):
    """Return the first ten entries of the sliceable ``x`` as a ``pandas.Series``.

    Shorter inputs yield a correspondingly shorter Series.
    """
    leading = x[:10]
    return pd.Series(leading)
_label_subset = ['label_5', 'label_10', 'label_30', 'label_50']

# One column per book level (first ten) for each side. Prices are cast
# float -> int, which truncates any fractional part; sizes stay float.
_ask_px = df_labeled['asks_px'].apply(top10).astype(float).astype(int).add_prefix("ask_px_")
_bid_px = df_labeled['bids_px'].apply(top10).astype(float).astype(int).add_prefix("bid_px_")
_ask_sz = df_labeled['asks_sz'].apply(top10).astype(float).add_prefix("ask_sz_")
_bid_sz = df_labeled['bids_sz'].apply(top10).astype(float).add_prefix("bid_sz_")

df_out = pd.concat(
    [_ask_px, _bid_px, _ask_sz, _bid_sz, df_labeled[_label_subset], df_labeled[['session_id']]],
    axis=1,
)

# Drop every row in which at least one of the selected labels is NaN.
df_out = df_out.dropna(subset=_label_subset)


正在计算 Mid-Price...
数据被切分为 35 个连续片段。
开始计算动态阈值并更新标签...
k=3: 下跌阈值 < -0.000002, 上涨阈值 > 0.000004
k=5: 下跌阈值 < -0.000015, 上涨阈值 > 0.000017
k=10: 下跌阈值 < -0.000039, 上涨阈值 > 0.000042
k=30: 下跌阈值 < -0.000107, 上涨阈值 > 0.000108
k=50: 下跌阈值 < -0.000149, 上涨阈值 > 0.000160

标签分布检查:

Label 3 分布:
label_3
-1.0    0.333303
 0.0    0.333393
 1.0    0.333303
Name: proportion, dtype: float64

Label 5 分布:
label_5
-1.0    0.333303
 0.0    0.333394
 1.0    0.333303
Name: proportion, dtype: float64

Label 10 分布:
label_10
-1.0    0.333303
 0.0    0.333394
 1.0    0.333303
Name: proportion, dtype: float64

Label 30 分布:
label_30
-1.0    0.333303
 0.0    0.333395
 1.0    0.333303
Name: proportion, dtype: float64

Label 50 分布:
label_50
-1.0    0.333302
 0.0    0.333396
 1.0    0.333302
Name: proportion, dtype: float64


  df_unlabeled = df.groupby('session_id', group_keys=False).apply(calculate_labels_within_group)


In [5]:
#dataset preparing function define
import pandas as pd
import numpy as np
import re
import os

def parse_first_price_vectorized(series):
    """Extract the level-1 (first) price from every element of ``series``.

    Each element may be:

    * a string such as ``"['100', '200']"`` or ``"['100' '200']"``: brackets,
      quotes and commas are treated as separators and the first token is
      converted;
    * a list / array: its first item is converted;
    * a plain scalar: converted directly.

    The element type is checked explicitly because ``series.str[0]`` does not
    fail on string columns; it silently returns the first *character*
    (``'['``), which would never reach a parsing fallback.

    Args:
        series: ``pandas.Series`` of price containers.

    Returns:
        ``pandas.Series`` with the same index holding floats; ``np.nan`` where
        no price can be extracted.
    """
    def _first_price(x):
        if isinstance(x, str):
            tokens = re.sub(r"[\[\]'\",]", " ", x).split()
            candidate = tokens[0] if tokens else None
        else:
            try:
                candidate = x[0]
            except (IndexError, KeyError):
                candidate = None  # empty container
            except TypeError:
                candidate = x     # not subscriptable: treat as a scalar
        try:
            return float(candidate)
        except (TypeError, ValueError):
            return np.nan

    return series.map(_first_price)

def process_lob_parquet(input_path, output_path, start_session_id=0):
    """Label one merged LOB parquet file, flatten its book and append it to ``output_path``.

    Steps:
        1. Read ``input_path`` and sort by ``exchange_time``.
        2. Compute the mid-price from the level-1 bid and ask.
        3. Split rows into sessions wherever consecutive ``exchange_time``
           values differ by more than ``GAP_THRESHOLD``; ids start at
           ``start_session_id``.
        4. Per session and horizon ``k``, compare the trailing ``k``-row mean
           of the mid-price with the mean of the next ``k`` rows, and assign
           -1 / 0 / 1 using the 33.33 % / 66.67 % quantiles of that change
           within this file.
        5. Expand the first ten levels of each book side into columns.
        6. Write (or append, if the file already exists) with fastparquet.

    Args:
        input_path: Path of the parquet file to process.
        output_path: Path of the parquet file to create or append to. Its
            parent directory is created when missing.
        start_session_id: First session id to use for this file.

    Returns:
        The session id the next file should start from. ``start_session_id``
        is returned unchanged when the input is missing or empty.
    """
    print(f"--- 处理文件: {input_path} ---")

    # 1. Load
    if not os.path.exists(input_path):
        print(f"错误: 文件不存在 {input_path}")
        return start_session_id

    df = pd.read_parquet(input_path)
    if df.empty:
        print("警告: 数据为空，跳过")
        return start_session_id

    df = df.sort_values(by="exchange_time").reset_index(drop=True)

    # 2. Mid-price from the level-1 quotes
    print("计算 Mid-Price...")
    df['bid_px_1'] = parse_first_price_vectorized(df['bids_px']).astype(float)
    df['asks_px_1'] = parse_first_price_vectorized(df['asks_px']).astype(float)
    df['mid_price'] = (df['bid_px_1'] + df['asks_px_1']) / 2

    # 3. Session ids, offset so they keep increasing across files
    GAP_THRESHOLD = 1000  # milliseconds (per the original author's note)

    df['time_diff'] = df['exchange_time'].diff()
    # The first row has a NaN difference, which compares False: no gap there.
    is_gap = (df['time_diff'] > GAP_THRESHOLD).fillna(False)
    df['session_id'] = is_gap.cumsum() + start_session_id

    current_max_session_id = df['session_id'].max()
    print(f"Session ID 范围: {df['session_id'].min()} -> {current_max_session_id}")

    # 4. Labels, computed inside each session so windows never cross a gap
    print("计算 Label...")
    k_values = [5, 10, 30, 50]
    grouped = df.groupby('session_id')['mid_price']

    for k in k_values:
        # Mean over the current row and the k - 1 rows before it.
        m_minus = grouped.transform(lambda x: x.rolling(window=k).mean())
        # The same rolling mean read k rows ahead, i.e. the mean of the next k rows.
        m_plus = grouped.transform(lambda x: x.rolling(window=k).mean().shift(-k))

        raw_change = (m_plus - m_minus) / m_minus

        valid_changes = raw_change.dropna()
        if valid_changes.empty:
            df[f'label_{k}'] = np.nan
            continue

        th_down = valid_changes.quantile(0.3333)
        th_up = valid_changes.quantile(0.6667)

        # Float from the start: assigning NaN into an int Series would force
        # a dtype upcast on item assignment.
        labels = pd.Series(0.0, index=df.index)
        labels[raw_change > th_up] = 1.0
        labels[raw_change < th_down] = -1.0
        labels[raw_change.isna()] = np.nan

        df[f'label_{k}'] = labels

    # Keep only rows for which every horizon has a label (dropna default how='any').
    label_cols = [f'label_{k}' for k in k_values]
    df = df.dropna(subset=label_cols)

    if df.empty:
        print("警告: 清洗后数据为空")
        return current_max_session_id + 1

    # 5. Flatten the first ten book levels into columns
    print("展开 Feature 列 (Top 10)...")

    def expand_column(col_name, prefix, dtype=float):
        """Expand a column of per-row sequences into its first ten entries."""
        expanded = pd.DataFrame(df[col_name].tolist()).iloc[:, :10]
        expanded.columns = [f"{prefix}{i}" for i in range(expanded.shape[1])]
        return expanded.astype(dtype)

    try:
        # Prices go float -> int, truncating any fractional part.
        # NOTE(review): confirm this is intended for markets with fractional prices.
        df_ask_px = expand_column('asks_px', 'ask_px_', float).astype(int)
        df_bid_px = expand_column('bids_px', 'bid_px_', float).astype(int)
        df_ask_sz = expand_column('asks_sz', 'ask_sz_', float)
        df_bid_sz = expand_column('bids_sz', 'bid_sz_', float)
    except Exception as e:
        print(f"快速展开失败，尝试兼容模式 (可能数据是String): {e}")
        # Slow fallback for columns stored as the string form of a list.
        import ast

        def safe_parse(x):
            """Parse a stringified list; unparsable rows become ten zeros."""
            try:
                return ast.literal_eval(str(x))[:10]
            except Exception:
                return [0] * 10

        df_ask_px = pd.DataFrame(df['asks_px'].apply(safe_parse).tolist()).astype(float).astype(int).add_prefix('ask_px_')
        df_bid_px = pd.DataFrame(df['bids_px'].apply(safe_parse).tolist()).astype(float).astype(int).add_prefix('bid_px_')
        df_ask_sz = pd.DataFrame(df['asks_sz'].apply(safe_parse).tolist()).astype(float).add_prefix('ask_sz_')
        df_bid_sz = pd.DataFrame(df['bids_sz'].apply(safe_parse).tolist()).astype(float).add_prefix('bid_sz_')

    # 6. Assemble. The expanded frames carry a fresh RangeIndex, so the pieces
    # taken from df are reset to match row by row.
    df_out = pd.concat([
        df[['exchange_time']].reset_index(drop=True),
        df_ask_px,
        df_bid_px,
        df_ask_sz,
        df_bid_sz,
        df[label_cols].reset_index(drop=True),
        df[['session_id']].reset_index(drop=True)
    ], axis=1)

    df_out = df_out.dropna()

    print(f"写入输出文件: {output_path}, Shape: {df_out.shape}")

    # Make sure the target folder exists before fastparquet opens the file.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # fastparquet is required for append; append only when the file exists.
    df_out.to_parquet(
        output_path,
        compression='snappy',
        engine='fastparquet',
        append=os.path.exists(output_path)
    )

    return current_max_session_id + 1



In [6]:
# Run the preparation function over every merged file.

markets=['BTC']
year=2025
mounth=12
days=['14']
nSigFigs=['5.0']
input_files = []
for market in markets:
    for day in days:
        for nsigfig in nSigFigs:
            input_path=f"/home/jack_li/python/LOB_research/fetch_data/data/{year}-{mounth}-{day}/{market}/nSigFigs={nsigfig}"
            output_path=f"/home/jack_li/python/LOB_research/fetch_data/data/{year}-{mounth}-{day}/{market}/nSigFigs={nsigfig}/merged.parquet"
            # The merged file written by the merge step is the input here.
            input_files.append(output_path)


output_dir = r'/home/jack_li/python/LOB_research/fetch_data/data/BTC'

# Session ids keep increasing across files.
global_session_id = 0
file_name = os.path.basename(input_files[0])
out_path = os.path.join(output_dir, f"LOB/processed_{file_name}")
# The output lives in the LOB/ sub-folder, so create that folder (and its
# parents), not only output_dir.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# Start from a clean file: process_lob_parquet appends when the file exists.
if os.path.exists(out_path):
    os.remove(out_path)

for f_path in input_files:
    file_name = os.path.basename(f_path)
    print(file_name)
    # Each call returns the first session id the next file should use.
    global_session_id = process_lob_parquet(f_path, out_path, start_session_id=global_session_id)

print("所有文件处理完成。")

merged.parquet
--- 处理文件: /home/jack_li/python/LOB_research/fetch_data/data/2025-12-14/BTC/nSigFigs=5.0/merged.parquet ---
计算 Mid-Price...
Session ID 范围: 0 -> 25
计算 Label...
展开 Feature 列 (Top 10)...
写入输出文件: /home/jack_li/python/LOB_research/fetch_data/data/BTC/LOB/processed_merged.parquet, Shape: (46196, 46)
写入输出文件: /home/jack_li/python/LOB_research/fetch_data/data/BTC/LOB/processed_merged.parquet, Shape: (46196, 46)
所有文件处理完成。


In [None]:
import pandas as pd
import numpy as np
# NOTE(review): the preparation cell writes to .../BTC/LOB/processed_merged.parquet
# while this cell reads .../BTC/processed_merged.parquet; confirm which file is intended.
df_out = pd.read_parquet(r'/home/jack_li/python/LOB_research/fetch_data/data/BTC/processed_merged.parquet')
# NOTE(review): rows without WINDOW preceding rows become NaN and are dropped
# below; the input must be longer than WINDOW for anything to remain.
WINDOW = 450000
k_values = [5, 10, 30, 50]
label_cols = [f'label_{k}' for k in k_values]

# Column names for ten levels, ask and bid interleaved per level.
all_price_cols = [f"{side}_px_{i}" for i in range(10) for side in ("ask", "bid")]
all_size_cols = [f"{side}_sz_{i}" for i in range(10) for side in ("ask", "bid")]

price_data = df_out[all_price_cols]  # 20 price columns
size_data = df_out[all_size_cols]    # 20 size columns

# Mean / standard deviation pooled over all price cells (printed only).
pooled_mean_price = price_data.values.mean()
pooled_devition_price = price_data.values.std()
print(pooled_mean_price)
print(pooled_devition_price)

all_cols = all_price_cols + all_size_cols
df = df_out[all_cols]

# Rolling statistics shifted by one row, so the current row is excluded.
_rolling = df.rolling(window=WINDOW, min_periods=WINDOW)
rolling_mean = _rolling.mean().shift(1)
rolling_std = _rolling.std().shift(1)

# Z-score against the preceding window, then re-attach labels and session id.
df_z = df.sub(rolling_mean).div(rolling_std)
df = pd.concat(
    [
        df_z.reset_index(drop=True),
        df_out[label_cols].reset_index(drop=True),
        df_out[['session_id']].reset_index(drop=True),
    ],
    axis=1,
)
df = df.dropna(how='any')



In [None]:
#define dataset
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

# 保持原来的 data_classification 函数不变
def data_classification(X, Y, T):
    """Cut ``X`` into overlapping windows of ``T`` rows and align ``Y`` with them.

    Window ``i`` holds rows ``i .. i + T - 1`` of ``X``; its target is the
    entry of ``Y`` at the window's last row.

    Args:
        X: 2-D array-like of shape ``(N, D)``.
        Y: Array-like with ``N`` entries along its first axis.
        T: Window length in rows.

    Returns:
        ``(dataX, dataY)`` where ``dataX`` is a float64 array of shape
        ``(N - T + 1, T, D)`` and ``dataY`` is ``Y[T - 1:N]``;
        ``(None, None)`` when ``N < T``.
    """
    n_rows, n_features = X.shape
    features = np.array(X)
    targets = np.array(Y)

    # Not enough rows for even a single window.
    if n_rows < T:
        return None, None

    n_windows = n_rows - T + 1
    dataX = np.zeros((n_windows, T, n_features))
    for start in range(n_windows):
        dataX[start] = features[start:start + T, :]

    dataY = targets[T - 1:n_rows]
    return dataX, dataY

class Dataset(data.Dataset):
    """Sliding-window dataset built per session, so no window spans two sessions.

    Column layout expected in ``df`` (selected by position):
        * the first 40 columns are the features;
        * the four columns before the last one are the labels;
        * a column named ``session_id`` is used for grouping.

    After construction ``self.x`` has shape ``(samples, 1, T, 40)`` (float32)
    and ``self.y`` holds one class index per sample (int64).
    """

    def __init__(self, df, k, num_classes, T, label_offset=1):
        """Build all windows eagerly.

        Args:
            df: Frame with the column layout described on the class.
            k: Index (0-3) of the label column to use as target.
            num_classes: Number of classes; class indices must lie in
                ``[0, num_classes)``.
            T: Window length in rows.
            label_offset: Value added to the raw labels to obtain class
                indices. The default ``1`` maps the -1 / 0 / 1 labels
                produced earlier in this notebook to 0 / 1 / 2. Pass ``-1``
                for labels stored as 1 / 2 / 3.

        Raises:
            ValueError: If every session is shorter than ``T``, or if the
                shifted labels fall outside ``[0, num_classes)``.
        """
        self.k = k
        self.num_classes = num_classes
        self.T = T

        all_x_list = []
        all_y_list = []

        grouped = df.groupby('session_id')

        print(f"开始处理 {len(grouped)} 个 Session 的数据拼接...")

        for session_id, group_data in grouped:
            x_session = group_data.iloc[:, :40].to_numpy()
            # The four label columns sit right before the trailing column.
            y_session = group_data.iloc[:, -5:-1].to_numpy()

            x_processed, y_processed = data_classification(x_session, y_session, self.T)

            # Sessions shorter than T yield (None, None) and are skipped.
            if x_processed is not None and len(x_processed) > 0:
                all_x_list.append(x_processed)
                all_y_list.append(y_processed)

        if len(all_x_list) > 0:
            final_x = np.concatenate(all_x_list, axis=0)
            final_y = np.concatenate(all_y_list, axis=0)
        else:
            raise ValueError("数据不足：所有 Session 的长度都小于 T，无法生成数据集。")

        print(f"数据处理完成。总样本数: {final_x.shape[0]}")

        # Pick the requested horizon and shift the raw labels to class
        # indices. The loss functions used for training need indices in
        # [0, num_classes); subtracting 1 from -1/0/1 labels would give
        # negative targets.
        final_y = final_y[:, self.k] + label_offset
        if final_y.min() < 0 or final_y.max() >= num_classes:
            raise ValueError(
                f"class indices after applying label_offset={label_offset} span "
                f"[{final_y.min()}, {final_y.max()}], expected [0, {num_classes - 1}]"
            )

        self.length = len(final_x)

        x_tensor = torch.from_numpy(final_x).float()
        self.x = torch.unsqueeze(x_tensor, 1)  # add the channel axis: (N, 1, T, D)
        self.y = torch.from_numpy(final_y).long()

    def __len__(self):
        """Return the total number of windows."""
        return self.length

    def __getitem__(self, index):
        """Return the ``(window, class_index)`` pair at ``index``."""
        return self.x[index], self.y[index]

# Usage example (the first parameter is named `df`, not `data`)
# dataset_train = Dataset(df=df, k=3, num_classes=3, T=100)

In [None]:
from torch.utils.data import DataLoader, random_split
dataset = Dataset(df, k=3, num_classes=3, T=100)
batch_size = 64

# 80 / 20 split sizes
total_size = len(dataset)
train_size = int(total_size * 0.8)
val_size = total_size - train_size

# Seed the split so that Restart & Run All reproduces the same partition.
# NOTE(review): consecutive windows overlap in all but one row, so a random
# split places near-duplicates of training windows in the held-out set;
# consider a chronological or per-session split for an honest validation loss.
split_generator = torch.Generator().manual_seed(42)
dataset_train, dataset_test = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)



In [None]:
# Display the repr of the training split returned by random_split.
dataset_train

In [None]:
# random_split returns Subset wrappers, which expose `.dataset` and
# `__len__` but not the `.x` / `.y` tensors of the wrapped Dataset. Report
# the training split's shapes from its length and the per-sample shape.
print(
    torch.Size((len(dataset_train), *dataset_train.dataset.x.shape[1:])),
    torch.Size((len(dataset_train),)),
)


In [None]:
class deeplob(nn.Module):
    """Convolution + inception + LSTM classifier.

    Input: tensor of shape ``(batch, 1, T, 40)``. The width of 40 is halved
    twice by the ``(1, 2)`` strided convolutions and collapsed to 1 by the
    ``(1, 10)`` convolution. The six ``(4, 1)`` convolutions shorten the
    time axis by 18 rows in total, so ``T`` must be at least 19.

    Output: tensor of shape ``(batch, y_len)`` holding softmax probabilities.
    """

    @staticmethod
    def _conv_block(in_channels, first_kernel, first_stride, activation):
        """Build one block: a width-reducing conv followed by two ``(4, 1)`` convs.

        Every convolution is followed by ``activation()`` and ``BatchNorm2d``;
        the layer order (and thus the ``state_dict`` keys) is
        conv, act, bn repeated three times.
        """
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=32,
                      kernel_size=first_kernel, stride=first_stride),
            activation(),
            nn.BatchNorm2d(32),
        ]
        for _ in range(2):
            layers += [
                nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4, 1)),
                activation(),
                nn.BatchNorm2d(32),
            ]
        return nn.Sequential(*layers)

    def __init__(self, y_len):
        """Create the network.

        Args:
            y_len: Number of output classes.
        """
        super().__init__()
        self.y_len = y_len

        def leaky():
            return nn.LeakyReLU(negative_slope=0.01)

        # convolution blocks
        self.conv1 = self._conv_block(1, (1, 2), (1, 2), leaky)
        self.conv2 = self._conv_block(32, (1, 2), (1, 2), nn.Tanh)
        self.conv3 = self._conv_block(32, (1, 10), (1, 1), leaky)

        # inception modules
        self.inp1 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1), padding='same'),
            leaky(),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 1), padding='same'),
            leaky(),
            nn.BatchNorm2d(64),
        )
        self.inp2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1), padding='same'),
            leaky(),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5, 1), padding='same'),
            leaky(),
            nn.BatchNorm2d(64),
        )
        self.inp3 = nn.Sequential(
            nn.MaxPool2d((3, 1), stride=(1, 1), padding=(1, 0)),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 1), padding='same'),
            leaky(),
            nn.BatchNorm2d(64),
        )

        # lstm layers: 3 inception branches x 64 channels = 192 input features
        self.lstm = nn.LSTM(input_size=192, hidden_size=64, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(64, self.y_len)

    def forward(self, x):
        """Return class probabilities for a batch of windows.

        Args:
            x: Tensor of shape ``(batch, 1, T, 40)``.

        Returns:
            Tensor of shape ``(batch, y_len)`` whose rows sum to 1.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x_inp1 = self.inp1(x)
        x_inp2 = self.inp2(x)
        x_inp3 = self.inp3(x)

        x = torch.cat((x_inp1, x_inp2, x_inp3), dim=1)

        # (batch, channels, time, 1) -> (batch, time, channels)
        x = x.permute(0, 2, 1, 3)
        x = torch.reshape(x, (-1, x.shape[1], x.shape[2]))

        # With no initial state given, nn.LSTM starts from zeros on the
        # input's own device, so no notebook-global `device` is needed.
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.fc1(x)
        forecast_y = torch.softmax(x, dim=1)

        return forecast_y

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
model = deeplob(y_len=dataset.num_classes)
model.to(device)

In [None]:
# NOTE(review): `summary` is not imported anywhere in the visible notebook;
# presumably torchinfo.summary. Confirm and import it before this cell.
# The tuple is passed as the second argument, presumably the input size (batch, channel, T, features).
summary(model, (1, 1, 100, 40))

In [None]:
# deeplob.forward already returns softmax probabilities, while
# nn.CrossEntropyLoss expects raw logits and applies log-softmax itself.
# Feeding probabilities into it squashes the outputs twice. Taking the log of
# the probabilities and using NLLLoss gives the intended cross-entropy.
_nll_loss = nn.NLLLoss()


def criterion(probs, targets):
    """Cross-entropy for a model that outputs class probabilities.

    Args:
        probs: Tensor ``(batch, classes)`` of probabilities.
        targets: Tensor ``(batch,)`` of int64 class indices.

    Returns:
        Scalar loss tensor.
    """
    # The clamp keeps log() finite when a probability underflows to 0.
    return _nll_loss(torch.log(probs.clamp_min(1e-12)), targets)


optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:
# A function to encapsulate the training loop
def batch_gd(model, criterion, optimizer, train_loader, test_loader, epochs):
    """Train ``model`` for ``epochs`` epochs and track the validation loss.

    After every epoch the mean loss over ``test_loader`` is computed. Whenever
    it improves on the best value so far, the whole model object is saved to
    ``./best_val_model_pytorch``.

    Batches are moved to the device on which the model's parameters live.
    ``tqdm`` must already be available in the enclosing namespace.

    Args:
        model: ``torch.nn.Module`` to train.
        criterion: Callable ``(outputs, targets) -> scalar loss tensor``.
        optimizer: Optimizer bound to ``model``'s parameters.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_losses, test_losses)``: two numpy arrays of length ``epochs``
        holding the mean batch loss per epoch.
    """
    # Follow the model instead of relying on a notebook-global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)
    best_test_loss = np.inf
    best_test_epoch = 0

    for it in tqdm(range(epochs)):

        model.train()
        t0 = datetime.now()
        train_loss = []
        for inputs, targets in train_loader:
            inputs = inputs.to(device, dtype=torch.float)
            targets = targets.to(device, dtype=torch.int64)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        # Mean of per-batch losses recorded while the weights were still
        # changing, so only an approximation of the end-of-epoch train loss.
        train_loss = np.mean(train_loss)

        model.eval()
        test_loss = []
        # No gradients are needed for validation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device, dtype=torch.float)
                targets = targets.to(device, dtype=torch.int64)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                test_loss.append(loss.item())
        test_loss = np.mean(test_loss)

        train_losses[it] = train_loss
        test_losses[it] = test_loss

        if test_loss < best_test_loss:
            torch.save(model, './best_val_model_pytorch')
            best_test_loss = test_loss
            best_test_epoch = it
            print('model saved')

        dt = datetime.now() - t0
        # Adjacent f-strings instead of a backslash continuation, which
        # embedded the source indentation into the printed message.
        print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, '
              f'Validation Loss: {test_loss:.4f}, Duration: {dt}, Best Val Epoch: {best_test_epoch}')

    return train_losses, test_losses

In [None]:
# Draw a single shuffled sample from the training split and show it.
tmp_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=1, shuffle=True)

first_batch = next(iter(tmp_loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x)
    print(y)
    print(x.shape, y.shape)

In [None]:
# Validate on the held-out loader. Passing train_loader twice made the
# "validation" loss, and the best-model checkpoint choice, depend on the
# training data.
train_losses, val_losses = batch_gd(model, criterion, optimizer,
                                    train_loader, test_loader, epochs=50)