## Generate labels

In [1]:
import pandas as pd
import nutils
import common as cm
import numpy as np
import os
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")


def calculate_stats(df, current, future):
    """Add forward-looking mean / variance / volatility label columns to ``df``.

    Three columns are written in place, each suffixed ``{current}-{future}``:

    * ``mean_*`` -- log ratio (in units of 1e-4) between the rolling mean of
      ``mid_price`` over ``future - current`` rows, shifted back by ``future``
      rows, and the row's own ``mid_price``.
    * ``var_*`` / ``vol_*`` -- rolling variance / standard deviation, over the
      same window and shift, of the 60-row-ahead log return of ``mid_price``.

    Args:
        df: Frame with a ``mid_price`` column; mutated in place.
        current: Start offset of the forward window, in rows.
        future: End offset of the forward window, in rows.

    Returns:
        None.

    NOTE(review): the return horizon is fixed at 60 rows regardless of
    ``current`` / ``future``, and the windows run over the whole frame without
    regard to any grouping of the index -- confirm both are intended.
    """
    mid = df["mid_price"]
    window = future - current
    suffix = f"{current}-{future}"

    # 60-row-ahead log return, scaled by 1e4.
    horizon_returns = np.log(mid.shift(-60) / mid) * 1e4

    # Shifting by -future aligns the window that ends `future` rows ahead
    # with the current row.
    forward_mean = mid.rolling(window).mean().shift(-future)
    df[f"mean_{suffix}"] = np.log(forward_mean / mid) * 1e4

    rolled_returns = horizon_returns.rolling(window)
    df[f"var_{suffix}"] = rolled_returns.var().shift(-future)
    df[f"vol_{suffix}"] = rolled_returns.std().shift(-future)


def calculate_stats2(df, current, future):
    """Add forward-looking max / min / gap label columns to ``df``.

    Three columns are written in place, each suffixed ``{current}-{future}``.
    The rolling max and min of ``mid_price`` are taken over
    ``future - current`` rows and shifted back by ``future`` rows:

    * ``max_*`` -- log ratio of that rolling max to the row's ``mid_price``,
      scaled by 1e4.
    * ``min_*`` -- log ratio of that rolling min to the row's ``mid_price``,
      scaled by 1e4.
    * ``gap_*`` -- (rolling max - rolling min) / ``mid_price``, scaled by 1e4
      (a plain ratio, not a log).

    Args:
        df: Frame with a ``mid_price`` column; mutated in place.
        current: Start offset of the forward window, in rows.
        future: End offset of the forward window, in rows.

    Returns:
        None.
    """
    mid = df["mid_price"]
    suffix = f"{current}-{future}"
    rolled = mid.rolling(future - current)

    # Window extremes aligned to the row `future` steps before the window end.
    forward_high = rolled.max().shift(-future)
    forward_low = rolled.min().shift(-future)

    df[f"max_{suffix}"] = np.log(forward_high / mid) * 1e4
    df[f"min_{suffix}"] = np.log(forward_low / mid) * 1e4
    df[f"gap_{suffix}"] = (forward_high - forward_low) / mid * 1e4


def calculate_stats3(df, cur, future):
    """Compute the ``future``-row-ahead log return of ``mid_price`` per index group.

    Rows are grouped by their index label (``df.groupby(df.index)``). Within
    each group the price ``future`` rows ahead is used; where that would run
    past the end of the group, the group's last ``mid_price`` is used instead.

    Args:
        df: Frame with a ``mid_price`` column. Rows sharing an index label
            form one group.
        cur: Unused; kept so the signature parallels the other label helpers.
        future: Look-ahead distance in rows.

    Returns:
        A one-column DataFrame named ``ret_{future}`` holding
        ``log(future_price / mid_price) * 1e4``, with the groups concatenated
        in ``groupby`` (sorted-key) order. An empty ``df`` yields an empty
        frame with that column.
    """
    column = f"ret_{future}"
    frames = []
    for _, group in df.groupby(df.index):
        mid = group["mid_price"]
        # Past the end of the group, fall back to the group's last row.
        future_prices = mid.shift(-future).fillna(mid.iloc[-1])
        ret = np.log(future_prices / mid) * 1e4
        # `ret` still carries the name "mid_price". Passing it to
        # pd.DataFrame(..., columns=[column]) would select a column that does
        # not exist and give all-NaN, so rename via to_frame instead.
        frames.append(ret.to_frame(name=column))
    if not frames:
        # pd.concat raises on an empty list.
        return pd.DataFrame(columns=[column], index=df.index[:0], dtype=float)
    return pd.concat(frames)


# Codes in cm.STK_CODES that are not in cm.SELECTED_CODES.
# NOTE(review): `stk_list` is not used below -- the loop iterates over
# cm.STK_CODES instead. Confirm which set of codes is intended.
stk_list = [code for code in cm.STK_CODES if code not in cm.SELECTED_CODES]

# For every code: load its snapshot, derive the label columns, and save each
# new column as its own .npy file.
for code in tqdm(cm.STK_CODES):
    datas = cm.get_snapshot(code)
    df = datas["tickData"]
    # First column of the timestamp array is used as the row key.
    # NOTE(review): presumably a YYYYMMDD integer date -- confirm.
    tm = datas["timestamp"][:, 0]
    # Keep only rows whose key is below 20240101.
    mask = tm < 20240101
    df = df[mask]
    tm = tm[mask]
    df = pd.DataFrame(df, columns=cm.COLS_SNAPSHOTS)
    # Rows sharing a key share an index label; calculate_stats3 groups on it.
    df.index = tm

    # Save the original column names
    original_columns = set(df.columns)
    # mid_price is added after the snapshot above, so it counts as a new
    # column and is saved along with the labels.
    df["mid_price"] = (df["AskPrice1"] + df["BidPrice1"]) / 2
    # (current, future) row offsets for each label window.
    cur_futs = [(0, 60), (60, 120), (60, 300), (60, 600), (60, 6000), (120, 180)]
    for cur, futs in cur_futs:
        calculate_stats(df, cur, futs)
        calculate_stats2(df, cur, futs)
        df[f"ret_{futs}"] = calculate_stats3(df, cur, futs)

    # Keep only the newly generated columns
    new_columns = set(df.columns) - original_columns
    # Missing values (e.g. windows running past the end of the data) become 0.
    df = df.fillna(0)

    # Specify the save path
    save_path = f"/mnt/disk1/multiobj_dataset/{code}/"
    os.makedirs(save_path, exist_ok=True)
    for col in new_columns:
        feature_path = os.path.join(save_path, f"{col}.npy")
        # Each column is stored as an (n_rows, 1) array.
        np.save(feature_path, df[col].values.reshape(-1, 1))

  1%|          | 6/997 [04:05<11:42:39, 42.54s/it]

In [None]:
cm.STK_CODES

[]

In [None]:
from numba import jit, prange


@jit(nopython=True, parallel=True)
def cal_label(mid_price, start, end, ts):
    """Compute forward-window price statistics for every row, day by day.

    Rows are grouped by their value in ``ts``. For row ``j`` of a group the
    window is the group's prices at offsets ``[j + start, j + end)``, clipped
    to the end of the group. Each statistic is expressed relative to the
    row's own price as ``log(stat / price_now) * 1e4``.

    Args:
        mid_price: 1-D array of prices.
        ts: 1-D array, same length as ``mid_price``, giving each row's group
            key (assumed to be a per-day key -- TODO confirm with caller).
        start: Window start offset in rows, relative to the current row.
        end: Window end offset in rows (exclusive).

    Returns:
        Tuple ``(mean, var, vol, max, min, gap, amplitude)`` of float arrays,
        one value per input row, in input row order. Rows whose window is
        empty (it starts at or after the end of the group) are NaN.

    NOTE(review): ``var``/``vol``/``gap`` keep the original formulas (log of
    the price variance / std / range over the current price). These differ
    from ``calculate_stats`` and ``calculate_stats2`` in this file, and a flat
    window gives log(0) for them -- confirm the intended definitions.
    ``amplitude`` is computed identically to ``gap``, as in the original.
    """
    n_rows = len(mid_price)
    # Preallocated outputs. Each row is written at most once and the groups
    # own disjoint row positions, so the parallel outer loop cannot race
    # (appending to shared lists inside prange would).
    mean_out = np.full(n_rows, np.nan)
    var_out = np.full(n_rows, np.nan)
    vol_out = np.full(n_rows, np.nan)
    max_out = np.full(n_rows, np.nan)
    min_out = np.full(n_rows, np.nan)
    gap_out = np.full(n_rows, np.nan)
    amplitude_out = np.full(n_rows, np.nan)

    ts_unique = np.unique(ts)
    for i in prange(len(ts_unique)):
        rows = np.nonzero(ts == ts_unique[i])[0]
        mp = mid_price[rows]
        day_length = len(mp)
        for j in range(day_length):
            lo = j + start
            hi = min(j + end, day_length)
            if lo >= hi:
                # No rows left in this group for the window: leave NaN.
                continue
            price_now = mp[j]
            window = mp[lo:hi]
            window_max = window.max()
            window_min = window.min()
            spread = np.log((window_max - window_min) / price_now) * 1e4

            row = rows[j]
            mean_out[row] = np.log(window.mean() / price_now) * 1e4
            var_out[row] = np.log(window.var() / price_now) * 1e4
            vol_out[row] = np.log(window.std() / price_now) * 1e4
            max_out[row] = np.log(window_max / price_now) * 1e4
            min_out[row] = np.log(window_min / price_now) * 1e4
            gap_out[row] = spread
            amplitude_out[row] = spread

    return mean_out, var_out, vol_out, max_out, min_out, gap_out, amplitude_out