In [1]:
import gc, os, copy, torch, pickle, bisect, logging, warnings
print(torch.cuda.is_available())
import pandas as pd
import numpy as np
# os.environ["POLARS_MAX_THREADS"] = "1"
import polars as pl
print(pl.thread_pool_size())
import sys 
import joblib
import kaggle_evaluation.jane_street_inference_server as js_server
import torch.nn as nn
import torch.optim as optim
from typing import Literal
from typing import Tuple, Union, List
from copy import deepcopy
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
import time
from tqdm import tqdm
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split
from multiprocessing import Manager
# Notebook-wide display, warning and logging configuration.
pd.set_option('display.max_colwidth', 2000)
# Silences every warning category for the rest of the session.
warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.DEBUG,  # log level
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # log record format
    datefmt="%Y-%m-%d %H:%M:%S",  # timestamp format
)
# Root logger (no name passed to getLogger).
logger = logging.getLogger()
# Bare last expression: the cell displays the current Unix timestamp.
time.time()

True
4


1736790023.323508

In [2]:
class CONFIG:
    """Notebook-level constants."""
    # Random seed; presumably applied to the RNGs elsewhere in the notebook (not visible here).
    seed = 666
    # Directory of the competition dataset under the Kaggle input mount.
    path = "/kaggle/input/jane-street-real-time-market-data-forecasting"

class SpecialCols:
    """Names of the columns that have a special role in the dataset."""

    # Identifier columns.
    date_id = "date_id"
    time_id = "time_id"
    symbol_id = "symbol_id"
    id_cols = [date_id, time_id, symbol_id]

    # Sample-weight column.
    weight_col = "weight"

    # Responder columns: the single target, all nine responders, and their "_lag1d" variants.
    target_col = "responder_6"
    target_cols = [f"responder_{k}" for k in range(9)]
    target_lag1d_cols = [f"responder_{k}_lag1d" for k in range(9)]

    # feature_00 ... feature_78 (zero-padded to two digits).
    feature_cols = ["feature_%02d" % k for k in range(79)]

In [3]:
def timing_decorator_with_params(name=""):
    """
    Decorator factory that prints how long each call to the wrapped function takes.

    :param name: Label printed in front of the wrapped function's name in the timing message.
    :return: A decorator. The decorated function forwards all arguments and returns the
        wrapped function's result unchanged.
    """
    # Local import so the cell stays runnable without touching the notebook's import cell.
    import functools

    def timing_decorator(func):
        # functools.wraps keeps __name__ / __doc__ of the wrapped function, so stacked
        # decorators, logging and introspection still see the real function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()  # record the start time
            result = func(*args, **kwargs)  # run the wrapped function
            end_time = time.time()  # record the end time
            print(f"Function '{name}--{func.__name__}' executed in {end_time - start_time:.6f} seconds")

            return result
        return wrapper
    return timing_decorator

测试

In [4]:
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
    """
    Return `df` with the index (axis=0) or the columns (axis=1) sorted, sorting only if needed.

    `df.sort_index()` can take a lot of time even when the index is already sorted, so the
    (cheap) monotonicity check is done first and the input object is returned untouched when
    nothing has to be done.

    Parameters
    ----------
    df : pd.DataFrame
        Object whose index/columns should be sorted (a pd.Series works for axis=0 as well).
    axis : int
        0 to sort the index, anything else to sort the columns.

    Returns
    -------
    pd.DataFrame
        `df` itself when the chosen axis is already monotonic increasing, otherwise a sorted copy.
    """
    idx = df.index if axis == 0 else df.columns
    # A monotonic-increasing index (plain or MultiIndex) is already in sort order.
    # Relying on this check alone avoids `MultiIndex.is_lexsorted()`, which newer pandas
    # no longer provides, and also sorts unsorted non-MultiIndex axes.
    if idx.is_monotonic_increasing:
        return df
    return df.sort_index(axis=axis)

def np_ffill(arr: np.array):
    """
    Forward-fill the NaN entries of a 1D numpy array.

    Every NaN is replaced by the closest non-NaN value to its left. NaNs in front of the
    first valid value fall back to element 0 (so they stay NaN when element 0 is NaN).

    Parameters
    ----------
    arr : np.array
        Input numpy 1D array

    Returns
    -------
    np.array
        A new array of the same length, built by fancy-indexing `arr`.
    """
    # np.isnan only accepts float input, hence the cast
    is_missing = np.isnan(arr.astype(float))
    positions = np.arange(is_missing.shape[0])
    # position of each valid element, 0 for the missing ones ...
    source = np.where(is_missing, 0, positions)
    # ... then a running maximum carries the last valid position forward
    source = np.maximum.accumulate(source)
    return arr[source]

class TSDataSamplerJ:
    """
    (T)ime-(S)eries DataSampler
    This is the result of TSDatasetH

    It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
    dataset based on tabular data.
    - On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
      data.

    If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
    more powerful subclasses.

    Known Issues:
    - For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result
      in a different data type

    """

    def __init__(
        self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
    ):
        """
        Build a dataset which looks like torch.data.utils.Dataset.

        Parameters
        ----------
        data : pd.DataFrame
            The raw tabular data. It is indexed by a two-level MultiIndex (the index is unstacked
            in `build_index`).
        start :
            The indexable start time
        end :
            The indexable end time
        step_len : int
            The length of the time-series step
        fillna_type : str
            How the sampler handles a sample when there is no row for a specific time step.
            none:
                fill with np.nan
            ffill:
                ffill with previous sample
            ffill+bfill:
                ffill with previous samples first and fill with later samples second
        dtype :
            Optional dtype forwarded to `np.array` when the dataframe is converted.
        flt_data : pd.Series
            a column of data(True or False) to filter data.
            None:
                keep all data

        """
        self.start = start
        self.end = end
        self.step_len = step_len
        self.fillna_type = fillna_type
        self.data = lazy_sort_index(data)

        kwargs = {"object": self.data}
        if dtype is not None:
            kwargs["dtype"] = dtype

        self.data_arr = np.array(**kwargs)  # Get index from numpy.array will much faster than DataFrame.values!
        # NOTE:
        # - append last line with full NaN for better performance in `__getitem__`
        # - Keep the same dtype will result in a better performance
        self.data_arr = np.append(
            self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
        )
        self.nan_idx = -1  # The last line is all NaN

        # the data type will be changed
        # The index of usable data is between start_idx and end_idx
        self.idx_df, self.idx_map = self.build_index(self.data)
        self.data_index = deepcopy(self.data.index)

        if flt_data is not None:
            if isinstance(flt_data, pd.DataFrame):
                assert len(flt_data.columns) == 1
                flt_data = flt_data.iloc[:, 0]
            # NOTE: bool(np.nan) is True !!!!!!!!
            # make sure reindex comes first. Otherwise extra NaN may appear.
            # The builtin `bool` is used because the `np.bool` alias is missing from
            # several NumPy releases and raises AttributeError there.
            flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool)
            self.flt_data = flt_data.values
            self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
            self.data_index = self.data_index[np.where(self.flt_data)[0]]
        self.idx_map = self.idx_map2arr(self.idx_map)

        self.start_idx, self.end_idx = self.data_index.slice_locs(
            start=start, end=end
        )
        self.idx_arr = np.array(self.idx_df.values, dtype=np.float64)  # for better performance

        # del self.data  # save memory

    @staticmethod
    def idx_map2arr(idx_map):
        """
        Convert the `{sample index: (row, col)}` dict into an int64 array of shape (max_key + 1, 2).

        Keys that are missing from the dict are filled with (int64 max, int64 max), an
        out-of-bound marker meaning "does not exist".
        """
        # pytorch data sampler will have better memory control without large dict or list
        # - https://github.com/pytorch/pytorch/issues/13243
        # - https://github.com/airctic/icevision/issues/613
        # So we convert the dict into int array.
        # The arr_map is expected to behave the same as idx_map

        dtype = np.int64
        # set a index out of bound to indicate the none existing
        no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)

        max_idx = max(idx_map.keys())
        arr_map = []
        for i in range(max_idx + 1):
            arr_map.append(idx_map.get(i, no_existing_idx))
        arr_map = np.array(arr_map, dtype=dtype)
        return arr_map

    @staticmethod
    def flt_idx_map(flt_data, idx_map):
        """
        Keep only the entries of `idx_map` whose flag in `flt_data` is truthy and renumber
        them consecutively from 0, preserving their order.
        """
        idx = 0
        new_idx_map = {}
        for i, exist in enumerate(flt_data):
            if exist:
                new_idx_map[idx] = idx_map[i]
                idx += 1
        return new_idx_map

    def get_index(self):
        """
        Get the pandas index of the data, it will be useful in following scenarios
        - Special sampler will be used (e.g. user want to sample day by day)
        """
        return self.data_index[self.start_idx : self.end_idx]

    def config(self, **kwargs):
        """Set arbitrary attributes on the sampler, one per keyword argument."""
        # Config the attributes
        for k, v in kwargs.items():
            setattr(self, k, v)

    @staticmethod
    def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
        """
        The relation of the data

        Parameters
        ----------
        data : pd.DataFrame
            The dataframe with <datetime, DataFrame>

        Returns
        -------
        Tuple[pd.DataFrame, dict]:
            1) the first element:  reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
                instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009  ...
                datetime
                2021-01-11        0        1        2        3        4        5  ...
                2021-01-12     4146     4147     4148     4149     4150     4151  ...
                2021-01-13     8293     8294     8295     8296     8297     8298  ...
                2021-01-14    12441    12442    12443    12444    12445    12446  ...
            2) the second element:  {<original index>: <row, col>}
        """
        # object incase of pandas converting int to float
        idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
        idx_df = lazy_sort_index(idx_df.unstack())
        # NOTE: the correctness of `__getitem__` depends on columns sorted here
        idx_df = lazy_sort_index(idx_df, axis=1)

        idx_map = {}
        for i, (_, row) in enumerate(idx_df.iterrows()):
            for j, real_idx in enumerate(row):
                if not np.isnan(real_idx):
                    idx_map[real_idx] = (i, j)
        return idx_df, idx_map

    @property
    def empty(self):
        """True when the sampler exposes no sample between start and end."""
        return len(self) == 0

    def _get_indices(self, row: int, col: int) -> np.array:
        """
        get series indices of self.data_arr from the row, col indices of self.idx_df

        Parameters
        ----------
        row : int
            the row in self.idx_df
        col : int
            the col in self.idx_df

        Returns
        -------
        np.array:
            `step_len` positions into self.data_arr (as floats; NaN marks a missing step),
            left-padded with NaN when fewer than `step_len` rows of history exist.
        """
        indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col]

        if len(indices) < self.step_len:
            indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])

        if self.fillna_type == "ffill":
            indices = np_ffill(indices)
        elif self.fillna_type == "ffill+bfill":
            indices = np_ffill(np_ffill(indices)[::-1])[::-1]
        else:
            assert self.fillna_type == "none"
        return indices

    def _get_row_col(self, idx) -> Tuple[int]:
        """
        get the col index and row index of a given sample index in self.idx_df

        Parameters
        ----------
        idx :
            the input of  `__getitem__`

        Returns
        -------
        Tuple[int]:
            the row and col index

        Raises
        ------
        KeyError
            when an integer `idx` falls outside [start_idx, end_idx)
        NotImplementedError
            when `idx` is neither an integer nor a tuple
        """
        # The the right row number `i` and col number `j` in idx_df
        if isinstance(idx, (int, np.integer)):
            real_idx = self.start_idx + idx
            if self.start_idx <= real_idx < self.end_idx:
                i, j = self.idx_map[real_idx]  # TODO: The performance of this line is not good
            else:
                raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
        elif isinstance(idx, tuple):
            # <TSDataSampler object>["datetime", "instruments"]
            # NOTE(review): the key is converted with pd.Timestamp, so this branch presumably
            # only suits a datetime-like first index level -- confirm before using it with
            # integer time ids.
            date, inst = idx
            date = pd.Timestamp(date)
            i = bisect.bisect_right(self.idx_df.index, date) - 1
            # NOTE: This relies on the idx_df columns sorted in `__init__`
            j = bisect.bisect_left(self.idx_df.columns, inst)
        else:
            raise NotImplementedError("This type of input is not supported")
        return i, j

    def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]):
        """
        # We have two method to get the time-series of a sample
        tsds is a instance of TSDataSampler

        # 1) sample by int index directly
        tsds[len(tsds) - 1]

        # 2) sample by <datetime,instrument> index
        tsds['2016-12-31', "SZ300315"]

        # The return value will be similar to the data retrieved by following code
        df.loc(axis=0)['2015-01-01':'2016-12-31', "SZ300315"].iloc[-30:]

        A list / np.ndarray of int indices returns a 3D array <sample_idx, step_idx, feature_idx>.

        Parameters
        ----------
        idx : Union[int, Tuple[object, str], List[int]]
        """
        # Multi-index type
        mtit = (list, np.ndarray)
        if isinstance(idx, mtit):
            indices = [self._get_indices(*self._get_row_col(i)) for i in idx]
            indices = np.concatenate(indices)
        else:
            indices = self._get_indices(*self._get_row_col(idx))

        # 1) for better performance, use the last nan line for padding the lost date
        # 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
        # precision problems. It will not cause any problems in my tests at least
        indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)

        data = self.data_arr[indices]
        if isinstance(idx, mtit):
            # if we get multiple indexes, addition dimension should be added.
            # <sample_idx, step_idx, feature_idx>
            data = data.reshape(-1, self.step_len, *data.shape[1:])
        return data

    def __len__(self):
        """Number of samples located between `start` and `end`."""
        return self.end_idx - self.start_idx

class DailyBatchSamplerRandom(Sampler):
    """
    Sampler that yields one batch of sample positions per `time_id`.

    Each yielded item is an np.ndarray holding the consecutive positions (in
    `data_source.get_index()`) of all rows that share one `time_id`.

    Parameters
    ----------
    data_source :
        Object exposing `get_index()`, which returns a pandas index with a level named
        "time_id" whose rows are grouped contiguously in ascending `time_id` order.
    shuffle : bool
        Stored on the instance only.
        NOTE(review): `__iter__` never consults this flag, so batches always come out in
        time order -- confirm whether shuffling was dropped on purpose.
    get_last_batch : bool
        When True the batches are yielded from the last `time_id` to the first.
    """

    def __init__(self, data_source, shuffle=False, get_last_batch=False):
        self.data_source = data_source
        self.shuffle = shuffle
        self.get_last_batch = get_last_batch
        # number of samples in each batch (one entry per time_id); the Series only serves as a
        # carrier for the index, its values are never read
        self.daily_count = (
            pd.Series(index=self.data_source.get_index(), dtype=pd.Float32Dtype())
            .groupby("time_id")
            .size()
            .values
        )
        # begin position of each batch: cumulative counts shifted right by one
        self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
        if len(self.daily_index) > 0:  # an empty data source has no first batch to anchor
            self.daily_index[0] = 0

    def __iter__(self):
        index = np.arange(len(self.daily_count))
        if self.get_last_batch:
            index = index[::-1]
        for i in index:
            yield np.arange(self.daily_index[i], self.daily_index[i] + self.daily_count[i])

    def __len__(self):
        # A sampler's length is the number of items `__iter__` yields, i.e. the number of
        # time_id batches, not the number of samples in the data source.
        return len(self.daily_count)


# Wrap a DataLoader so that it produces 4-dimensional inputs
def batch_combine(data_loader, batch_size, drop_last=True):
    """
    Merge consecutive DataLoader outputs into batches whose stock count N is uniform.

    Args:
        data_loader (DataLoader): PyTorch data loader; every item has shape
            [1, N, step_len, feature_dim].
        batch_size (int): Number of loader items merged into one batch.
        drop_last (bool): When True, items still buffered once the loader is exhausted
            are discarded instead of being yielded as a final, smaller batch.

    Yields:
        torch.Tensor: The merged batch, shape [batch_size, N, step_len, feature_dim].
        A batch holds fewer than `batch_size` items when N changes before it is full.
    """
    pending = []      # loader items collected for the batch being built
    group_n = None    # stock count shared by the items in `pending`

    for item in data_loader:
        # item.shape: [1, N, step_len, feature_dim]
        _, n_stocks, _, _ = item.shape

        if group_n is not None and n_stocks != group_n:
            # Stock count changed: emit what has been collected, then start a new group.
            if pending:
                yield torch.cat(pending, dim=0)
            pending = [item]
            group_n = n_stocks
            continue

        # Same stock count as the current group (or no group started yet).
        group_n = n_stocks
        pending.append(item)
        if len(pending) == batch_size:
            # Full batch: merge into [batch_size, N, step_len, feature_dim] and reset.
            yield torch.cat(pending, dim=0)
            pending = []
            group_n = None

    # Leftover items (if any). TOFIX (original author's note): there is a bug here,
    # drop_last must be used.
    if pending and not drop_last:
        yield torch.cat(pending, dim=0)

def get_dataloader(data: pl.DataFrame, step_len, get_last_batch=False):
    """
    Build a DataLoader that yields, per time step, a `step_len`-long window for every symbol.

    The (date_id, time_id) pair is flattened into one running `time_id` so that windows can
    span day boundaries: every date is shifted by the total number of time steps of all
    earlier dates.

    Parameters
    ----------
    data : pl.DataFrame
        Frame containing the date_id / time_id / symbol_id columns, the "weight" column and
        the "responder_0".."responder_8" columns; the original note says it has the format
        returned by `dp.get()` (defined outside this cell).
    step_len : int
        Window length passed to TSDataSamplerJ.
    get_last_batch : bool
        Passed to DailyBatchSamplerRandom; when True the batches come last-time-step first.

    Returns
    -------
    DataLoader
        Loader over a TSDataSamplerJ dataset, one batch per (flattened) time_id.
    """
    data = data.sort([SpecialCols.date_id, SpecialCols.time_id])

    # Number of time steps per date (max time_id + 1).
    # `group_by` does not keep the groups in any particular order, so the result is sorted
    # by date explicitly: the shift/cum_sum below is only meaningful on chronological rows.
    counts = (
        data.group_by("date_id")
        .agg(pl.col("time_id").max())
        .with_columns(pl.col("time_id") + 1)
        .sort("date_id")
    )
    # offset of a date = total number of time steps of all earlier dates
    offsets = counts.with_columns(pl.col("time_id").shift(1, fill_value=0).cum_sum().alias("offset"))
    data = data.join(offsets.select(["date_id", "offset"]), on="date_id")
    data = data.with_columns((pl.col("time_id") + pl.col("offset")).alias("time_id"))
    data = data.drop(["offset", "date_id"])
    data = data.sort(["time_id", "symbol_id"])

    # Reorder the columns: the weight and responder columns go last.
    notin = ["weight"] + [f"responder_{i}" for i in range(9)]
    columns_order = [col for col in data.columns if col not in notin] + notin
    data = data.select(columns_order)

    # Smallest and largest flattened time_id: the sampler covers the whole range.
    min_time_id = data.select(pl.col("time_id").min()).item()
    max_time_id = data.select(pl.col("time_id").max()).item()
    data = data.to_pandas().set_index(['time_id', 'symbol_id'])

    # Build the DataLoader.
    dataset = TSDataSamplerJ(data=data, start=min_time_id, end=max_time_id, step_len=step_len, fillna_type='ffill+bfill')
    sampler = DailyBatchSamplerRandom(dataset, shuffle=False, get_last_batch=get_last_batch)
    loader = DataLoader(dataset, sampler=sampler) #, num_workers=4)
    return loader

In [5]:
# Simulated inference interface
def generate_data_batches(test_path, lags_path):
    """
    Replay the partitioned test/lags parquet data one time step at a time.

    For every date found under `test_path` (in ascending order) the matching
    `date_id=<n>` partitions of the test and lags data are read, and one item is yielded
    per `time_id`.

    Yields
    ------
    tuple
        `((test, lags), validation_data)` where `test` holds the rows of one time_id,
        `lags` is the date's lags frame for time_id 0 and None otherwise, and
        `validation_data` is the `row_id` column of `test`.
    """
    unique_dates = (
        pl.scan_parquet(test_path)
        .select(pl.col("date_id").unique())
        .collect()
        .get_column("date_id")
    )
    date_ids = sorted(unique_dates)
    assert date_ids[0] == 0

    for date_id in date_ids:
        partition = f"date_id={date_id}"
        day_test = pl.read_parquet(os.path.join(test_path, partition))
        day_lags = pl.read_parquet(os.path.join(lags_path, partition))

        for (time_id,), test in day_test.group_by("time_id", maintain_order=True):
            # the lags frame accompanies only the first time step of the date
            lags_payload = day_lags if time_id == 0 else None
            yield (test, lags_payload), test.select('row_id')

In [6]:
# Features to eliminate, listed here because the .pkl has not been modified
eliminated_feats = [
    # "*_beta" / "*_residual" names
    'feature_20_beta',
    'feature_24_beta',
    'feature_25_beta',
    'feature_27_beta',
    'feature_29_beta',
    'responder_3_dailymean_lag1d_beta',
    'responder_5_dailymean_lag1d_beta',
    'responder_8_dailymean_lag1d_beta',
    'responder_8_dailymean_lag1d_residual',
    # "*_same_timeid" names
    'feature_22_avg_same_timeid',
    'feature_31_avg_same_timeid',
    'feature_58_std_same_timeid',
    'feature_60_std_same_timeid',
    'feature_61_avg_same_timeid',
    'feature_66_avg_same_timeid',
    'feature_67_avg_same_timeid',
    'responder_6_lag1d_avg_same_timeid',
    'responder_7_lag1d_avg_same_timeid',
    'weight_avg_same_timeid',
    # "*_lag_1_d" names
    'feature_05_beta_mean_lag_1_d',
    'feature_06_beta_mean_lag_1_d',
    'feature_06_beta_std_lag_1_d',
    'feature_07_beta_std_lag_1_d',
    'feature_16_beta_std_lag_1_d',
    'feature_21_beta_std_lag_1_d',
    'feature_37_residual_std_lag_1_d',
    'feature_47_residual_mean_lag_1_d',
    'feature_54_beta_last_lag_1_d',
    'feature_56_beta_mean_lag_1_d',
    'feature_57_beta_mean_lag_1_d',
    'feature_61_beta_std_lag_1_d',
]

class DailyFeature:
    """Base class of a streaming feature; one object per instrument."""

    def __init__(self, instrument_id, time_id, window, features_n):
        """Store the settings shared by all streaming features.

        Args:
            instrument_id (str): id of the instrument
            time_id: stored on the instance; not used by the code in this cell
            window (int): size of the rolling window
            features_n (int): number of features
        """
        self.instrument_id = instrument_id
        self.time_id = time_id
        self.window = window
        self.features_n = features_n
        # result returned while there is not enough data: one row of NaN
        self.empty_result = np.full((1, features_n), np.nan)

    def load_new_value(self):
        """Feed in new data (overridden by subclasses)."""
        pass

    def get_cur_result(self):
        """Compute the current result (overridden by subclasses)."""
        pass


class TSMEAN_STD(DailyFeature):
    """
    Rolling mean and (population) standard deviation over the last `window` observations.

    The values are derived from running cumulative sums of x and x*x; only the last
    `window + 1` cumulative rows are kept, so the sum over the window is the difference
    between the newest and the oldest kept row.
    """

    def __init__(self, instrument_id, time_id, window, features_n):
        super().__init__(instrument_id, time_id, window, features_n)

        # cumulative sums of x and x*x, one row per kept observation
        self.x_sum = np.empty(shape=(0, features_n))
        self.xx_sum = np.empty(shape=(0, features_n))
        self.save_window = self.window + 1

        self.n = 0      # divisor used by get_cur_result, capped at `window`
        self.count = 0  # total number of observations loaded

    def load_new_value(self, xs_features_data: np.ndarray):
        """Append one observation.

        Args:
            xs_features_data (np.ndarray): array of shape (1, features_n); NaN entries are
                treated as 0.
        """
        if self.n < self.window:
            self.n += 1
        self.count += 1

        # Prevent overflow: periodically rebase the cumulative sums on their oldest row.
        # Only differences of rows are ever used, so results are unaffected. The emptiness
        # check keeps window == 1 from indexing an empty buffer on the very first call.
        if self.x_sum.shape[0] > 0 and self.count % self.window == 0:
            self.x_sum = self.x_sum - self.x_sum[0]
            self.xx_sum = self.xx_sum - self.xx_sum[0]

        # fill NaN with 0
        xs_features_data = np.nan_to_num(xs_features_data, nan=0)
        assert self.x_sum.shape[1] == xs_features_data.shape[1]
        squares = xs_features_data * xs_features_data

        if self.x_sum.shape[0] > 0:
            # existing data: the new cumulative row continues from the latest one
            next_x = self.x_sum[-1] + xs_features_data
            next_xx = self.xx_sum[-1] + squares
        else:
            # no data yet: the observation itself is the first cumulative row
            next_x = xs_features_data
            next_xx = squares

        self.x_sum = np.concatenate((self.x_sum, next_x), axis=0)[-self.save_window:]
        self.xx_sum = np.concatenate((self.xx_sum, next_xx), axis=0)[-self.save_window:]

    def get_cur_result(self):
        """Return `(mean, std)` of the current window.

        Returns:
            tuple: two arrays of shape (features_n,); while fewer than two observations are
            stored, `(empty_result, empty_result)` (shape (1, features_n), all NaN) instead.
        """
        if self.x_sum.shape[0] <= 1:
            return self.empty_result, self.empty_result

        # NOTE(review): until more than `window` observations have been loaded, the row
        # difference below covers one observation fewer than `n` counts -- confirm whether
        # the warm-up values are meant to match an offline rolling mean/std.
        n = self.n
        total = self.x_sum[-1] - self.x_sum[0]
        mean = total / n

        variance = (self.xx_sum[-1] - self.xx_sum[0]) / n - total * total / (n * n)
        # Floating-point cancellation can leave a tiny negative variance for (nearly)
        # constant data; clamp it so the result is 0 instead of NaN.
        std = np.sqrt(np.maximum(variance, 0.0))

        return mean, std

class HFDataProcessor:
    # Column layout applied to every processed raw frame by
    # easy_process_raw_data: row_id, the id columns, weight, is_scored,
    # the raw feature columns, then the one-day-lagged responders.
    col_orders = ['row_id'] + SpecialCols.id_cols + [SpecialCols.weight_col] + ['is_scored'] + \
        SpecialCols.feature_cols + SpecialCols.target_lag1d_cols
    all_selected_feats = ['feature_00_market_simple_mean', 'feature_05_market_weighted_mean', 'feature_37_market_weighted_mean', 'feature_60_market_weighted_sum', 'feature_09_market_weighted_sum', 'feature_39_market_simple_mean', 'feature_39_market_weighted_mean', 'feature_10_market_weighted_sum', 'feature_01_market_weighted_sum', 'feature_58_market_weighted_sum', 'feature_62_market_weighted_mean', 'feature_45_market_simple_mean', 'feature_29_market_weighted_sum', 'feature_04_market_weighted_mean', 'feature_26_market_simple_mean', 'feature_60_market_simple_mean', 'feature_37_market_simple_mean', 'feature_42_market_simple_mean', 'feature_57_market_simple_mean', 'feature_29_market_simple_mean', 'feature_01_market_simple_mean', 'feature_73_market_simple_mean', 'feature_18_market_simple_mean', 'feature_11_market_weighted_sum', 'feature_46_market_simple_mean', 'feature_07_market_weighted_mean', 'feature_09_market_simple_mean', 'feature_21_market_simple_mean', 'feature_06_market_weighted_mean', 'feature_01_market_weighted_mean', 'feature_11_market_simple_mean', 'feature_02_market_simple_mean', 'feature_53_market_simple_mean', 'feature_78_market_simple_mean', 'feature_04_market_simple_mean', 'feature_59_market_weighted_sum', 'feature_49_market_simple_mean', 'feature_03_market_weighted_mean', 'feature_38_market_simple_mean', 'feature_05_market_simple_mean', 'feature_02_market_weighted_mean', 'feature_56_market_weighted_sum', 'feature_33_market_simple_mean', 'feature_49_market_weighted_sum', 'feature_53_market_weighted_mean', 'feature_08_market_simple_mean', 'feature_04_market_weighted_sum', 'feature_20_market_simple_mean', 'feature_08_market_weighted_mean', 'feature_37_market_weighted_sum', 'feature_59_market_weighted_sum_gpby_feature_09', 'feature_06_market_simple_mean_gpby_feature_09', 'feature_47_market_weighted_sum_gpby_feature_09', 'feature_58_market_weighted_sum_gpby_feature_09', 'feature_38_market_weighted_mean_gpby_feature_09', 
'feature_49_market_simple_mean_gpby_feature_09', 'feature_68_market_weighted_sum_gpby_feature_09', 'feature_59_market_weighted_mean_gpby_feature_09', 'feature_48_market_weighted_sum_gpby_feature_09', 'feature_06_market_weighted_sum_gpby_feature_09', 'feature_38_market_weighted_sum_gpby_feature_09', 'feature_30_market_weighted_sum_gpby_feature_09',
                          'feature_60_market_weighted_sum_gpby_feature_09', 'feature_56_market_weighted_sum_gpby_feature_09', 'feature_56_market_simple_mean_gpby_feature_09', 'feature_59_weighted_signal', 'feature_36_weighted_signal', 'feature_31_weighted_signal', 'feature_56_weighted_signal', 'weight_weighted_signal', 'feature_72_weighted_signal', 'feature_60_weighted_signal', 'feature_45_weighted_signal', 'feature_55_weighted_signal', 'feature_09_weighted_signal', 'feature_58_weighted_signal', 'feature_07_last_lag_1_d', 'feature_01_mean_lag_1_d', 'feature_01_std_lag_1_d', 'feature_16_std_lag_1_d', 'feature_38_mean_lag_1_d', 'feature_08_mean_lag_1_d', 'feature_02_std_lag_1_d', 'feature_04_mean_lag_1_d', 'feature_05_mean_lag_1_d', 'feature_30_mean_lag_1_d', 'feature_04_std_lag_1_d', 'feature_05_last_lag_1_d', 'feature_38_last_lag_1_d', 'feature_61_mean_lag_1_d', 'feature_72_avg_same_symbol', 'feature_07_avg_same_symbol', 'feature_59_avg_same_symbol', 'feature_56_avg_same_symbol', 'feature_41_avg_same_symbol', 'feature_16_std_same_symbol', 'feature_37_avg_same_symbol', 'feature_06_std_same_symbol', 'feature_55_avg_same_symbol', 'feature_59_std_same_symbol', 'feature_52_std_same_symbol', 'feature_36_avg_same_symbol', 'feature_07_std_same_symbol', 'feature_58_avg_same_symbol', 'feature_19_avg_same_symbol', 'feature_08_avg_same_symbol', 'feature_02_avg_same_symbol', 'feature_68_avg_same_symbol', 'feature_38_std_same_symbol', 'feature_30_avg_same_symbol', 'feature_37_std_same_symbol', 'feature_52_avg_same_symbol', 'feature_18_avg_same_symbol', 'feature_51_avg_same_symbol', 'feature_38_avg_same_symbol', 'feature_70_avg_same_symbol', 'feature_05_avg_same_symbol', 'feature_58_std_same_symbol', 'feature_48_std_same_symbol', 'feature_15_avg_same_symbol', 'feature_54_avg_same_symbol', 'feature_01_avg_same_symbol', 'weight_avg_same_symbol', 'feature_04_avg_same_symbol', 'feature_66_avg_same_symbol', 'feature_00_avg_same_symbol', 'feature_57_avg_same_symbol', 
'feature_60_avg_same_symbol', 'feature_60_std_same_symbol', 'feature_30_simple_lag_1', 'feature_59_simple_lag_1', 'feature_60_simple_lag_1', 'feature_69_simple_lag_1', 'feature_58_simple_lag_1', 'feature_07_simple_lag_1', 'feature_25_market_simple_mean_gpby_feature_10', 'feature_49_market_weighted_sum_gpby_feature_10', 'feature_24_market_simple_mean_gpby_feature_10', 'feature_07_zscore']
    # Loaded once, when the class body is executed. gen_feats iterates it
    # with .items(): keys are feature-function names (optionally suffixed
    # "_gpby_<column>") and each value is indexed with "func_origin_feats"
    # and "func_selected_feats".
    # NOTE(review): unpickling can execute arbitrary code; this is only safe
    # while the dataset file is trusted.
    function_to_use = pd.read_pickle(
        r"/kaggle/input/aft-js-dataset/function_to_use.pkl"
    )

    def __init__(self, ts_window=100) -> None:
        """Create empty buffers and load the pre-fitted PCA models.

        :param ts_window: window length stored as ``self._ts_window`` and used
            by the rolling time-series features (percentile, z-score).
        """
        # Sizes
        self._ts_window = ts_window
        self._seq_len = 16  # number of cross-sections returned
        self._keep_length = 6000  # how many raw rows are retained

        # Frames (all unset until data arrives)
        self._data: pl.DataFrame = None  # raw data
        self.return_data: pl.DataFrame = None
        self.ffill_values: pl.DataFrame = None  # forward-fill related
        # latest raw batch after easy_process_raw_data
        self.cur_processed_raw_data: pl.DataFrame = None
        # features computed once per day
        self.ts_daily_lag_res = None

        self.time_queue = []
        # stores, for each time_id, the concat of (raw features, data)
        self._cache = []

        # State used by the inter-day feature computation
        self.symbols = set()
        self.new_symbols = set()
        self.tsmean_stder = dict()

        # Feature clusters; each one is reduced by its own PCA model in add_pca.
        self.clustered_features = {
            'class_1': ['feature_09', 'feature_10', 'feature_11', 'feature_20', 'feature_21', 'feature_22', 'feature_23', 'feature_24', 'feature_25', 'feature_26', 'feature_27', 'feature_28', 'feature_29', 'feature_30', 'feature_31'],
            'class_2': ['feature_07', 'feature_08', 'feature_37', 'feature_38', 'feature_41', 'feature_45', 'feature_49', 'feature_51', 'feature_52', 'feature_55', 'feature_56', 'feature_60'],
            'class_3': ['feature_02', 'feature_04', 'feature_06', 'feature_34', 'feature_36'],
            'class_4': ['feature_73', 'feature_74', 'feature_75', 'feature_76', 'feature_77', 'feature_78'],
            'class_5': ['feature_15', 'feature_16', 'feature_17', 'feature_61', 'feature_62', 'feature_63', 'feature_64'],
            'class_6': ['feature_05', 'feature_39', 'feature_42', 'feature_47', 'feature_50', 'feature_53', 'feature_58'],
            'class_7': ['feature_40', 'feature_43', 'feature_44', 'feature_46', 'feature_48', 'feature_54', 'feature_57', 'feature_59'],
            'class_8': ['feature_18', 'feature_19', 'feature_65', 'feature_66'],
            'class_9': ['feature_00', 'feature_01', 'feature_03', 'feature_32', 'feature_33', 'feature_35'],
            'class_10': ['feature_12', 'feature_13', 'feature_14', 'feature_67', 'feature_68', 'feature_69', 'feature_70', 'feature_71', 'feature_72']
        }

        # Loading the models is slow, so do it up front rather than per batch.
        self.pca_models = {}
        for class_name in self.clustered_features:
            self.pca_models[class_name] = joblib.load(
                f'/kaggle/input/aft-js-dataset/pca_model_f79/pca_model_1_{class_name}.pkl')

    def add_pca(self, new_data):
        """Append one ``pca_<class_name>`` column per feature cluster.

        For every entry of ``self.clustered_features`` the listed columns are
        passed to the matching model in ``self.pca_models`` and the flattened
        output is stored as a new column.

        :param new_data: polars frame containing all clustered feature columns.
        :return: ``new_data`` with the PCA columns appended in cluster order.
        """
        pca_columns = []
        for class_name, class_list in self.clustered_features.items():
            cluster_values = new_data[class_list].to_numpy()
            projected = self.pca_models[class_name].transform(cluster_values)
            # reshape(-1) assumes one output component per row — TODO confirm
            pca_columns.append(
                pl.Series(f"pca_{class_name}", projected.reshape(-1)))
        return new_data.with_columns(pca_columns)

    @staticmethod
    def easy_process_raw_data(data: pl.DataFrame, lag1_responder: pl.DataFrame):
        """Attach the lagged responders to a raw batch and fix the column order.

        :param data: raw batch (responder columns may be filled with 0).
        :param lag1_responder: frame whose responder columns hold the values
            lagged by one day; they are renamed to ``<responder>_lag1d``.
        :return: ``data`` left-joined with the lagged responders on the id
            columns, restricted and ordered to ``HFDataProcessor.col_orders``.
        """
        lag_names = {}
        for responder in SpecialCols.target_cols:
            lag_names[responder] = responder + "_lag1d"

        renamed_lags = lag1_responder.rename(lag_names)
        joined = data.join(renamed_lags, on=SpecialCols.id_cols, how="left")
        return joined[HFDataProcessor.col_orders]

    # @timing_decorator_with_params("hf")
    def _trans_raw_data_2pd(self, data: pl.DataFrame):
        """Convert a raw polars frame to pandas, indexed and sorted by the id columns.

        Each index level is also kept as a plain column
        (``time_id_col``, ``symbol_id_col``, ``date_id_col``).
        """
        indexed = data.to_pandas().set_index(SpecialCols.id_cols).sort_index()
        level_copies = (
            ("time_id_col", SpecialCols.time_id),
            ("symbol_id_col", SpecialCols.symbol_id),
            ("date_id_col", SpecialCols.date_id),
        )
        for column_name, level in level_copies:
            indexed[column_name] = indexed.index.get_level_values(level)
        return indexed

    # @timing_decorator_with_params("hf")
    def _trans_back_2pl(self, data: pd.DataFrame):
        """Convert a pandas result back to polars.

        The index is reset so the id columns become ordinary columns again,
        and ``symbol_id`` is cast to ``Int8``.
        """
        flat = pl.DataFrame(data.reset_index())
        symbol_as_int8 = pl.col(SpecialCols.symbol_id).cast(pl.Int8)
        return flat.with_columns(symbol_as_int8)

    #
    # @timing_decorator_with_params("hf")
    def xs_signal_weighted(
        self,
        feats: List[str] = None,
    ) -> pl.DataFrame:
        """Cross-sectional signed weights around the column median.

        For every feature column of ``self.cur_processed_raw_data`` the column
        median is subtracted. Values above the median are divided by the sum
        of all positive deviations; values below it are divided by the sum of
        all negative deviations and negated; everything else becomes 0.

        :param feats: feature columns to use; ``None`` uses every non-id column.
        :return: the id columns plus one ``<col>_weighted_signal`` column per feature.
        """
        if feats is None:
            pl_data = self.cur_processed_raw_data.clone()
        else:
            pl_data = self.cur_processed_raw_data[SpecialCols.id_cols + feats]

        feat_cols = [
            name for name in pl_data.columns if name not in SpecialCols.id_cols]
        values = pl_data[feat_cols].to_numpy()

        # NOTE(review): np.median propagates NaN while the sums below use
        # nansum; a single NaN zeroes the whole column's signal — confirm
        # this is intended.
        column_medians = np.median(values, axis=0, keepdims=True)
        centered = values - column_medians
        above = centered > 0
        below = centered < 0

        positive_total = np.nansum(np.where(above, centered, 0), axis=0)
        negative_total = np.nansum(np.where(below, centered, 0), axis=0)

        signal = np.where(
            above,
            centered / positive_total,
            np.where(below, -(centered / negative_total), 0),
        )

        signal_frame = pl.DataFrame(signal)
        signal_frame.columns = [name + '_weighted_signal' for name in feat_cols]
        return pl.concat(
            [pl_data[SpecialCols.id_cols], signal_frame], how='horizontal')

    #
    # @timing_decorator_with_params("hf")
    def xs_market(
        self,
        feats: List[str] = None,
        groupby_col: Literal["feature_09", "feature_10"] = None,
    ) -> pl.DataFrame:
        """Cross-sectional "market" aggregates of the current batch.

        Three aggregates are computed per feature over
        ``self.cur_processed_raw_data``: the simple mean, and the mean and sum
        of ``feature * weight / sum(weight)``. The weight is normalised over
        the whole batch, also when ``groupby_col`` is given.

        :param feats: feature columns to aggregate; ``None`` uses every
            non-id, non-weight column. The list is not modified.
        :param groupby_col: optional column to aggregate within
            (``"feature_09"`` or ``"feature_10"``); it is added to the
            selection when missing and excluded from the aggregated features.
        :return: one row per input row. Columns are the id columns (plus
            ``groupby_col`` when given) and the aggregates, named
            ``<col>_market_simple_mean``, ``<col>_market_weighted_mean`` and
            ``<col>_market_weighted_sum`` with a ``_gpby_<groupby_col>``
            suffix in the grouped case.
        """
        assert groupby_col in [
            None,
            "feature_09",
            "feature_10",
        ], "groupby_col is invalid"

        if feats is not None:
            # Copy first: appending to the caller's list would leak
            # groupby_col into whatever the caller does with it afterwards.
            feats = list(feats)
            if groupby_col is not None and groupby_col not in feats:
                feats.append(groupby_col)
            data = self.cur_processed_raw_data[
                SpecialCols.id_cols + [SpecialCols.weight_col] + feats]
        else:
            data = self.cur_processed_raw_data.clone()

        feat_cols = [
            feat
            for feat in data.columns
            if feat not in [SpecialCols.weight_col] + SpecialCols.id_cols
        ]

        # The grouping key itself is never aggregated.
        if groupby_col is not None:
            feat_cols = [col for col in feat_cols if col != groupby_col]

        standard_weight = data[SpecialCols.weight_col] / \
            data[SpecialCols.weight_col].sum()

        weighted_data = data.with_columns(
            [(pl.col(col) * standard_weight) for col in feat_cols])

        if groupby_col is not None:
            # maintain_order=True keeps the three results row-aligned so they
            # can be concatenated horizontally.
            reshaped_mean = data[feat_cols + [groupby_col]].group_by(groupby_col, maintain_order=True).mean().rename(
                lambda column_name: column_name + "_market_simple_mean_gpby_" + groupby_col if column_name != groupby_col else column_name)
            reshaped_weighted_mean = weighted_data.group_by(groupby_col, maintain_order=True).mean(
            ).rename(lambda column_name: column_name + "_market_weighted_mean_gpby_" + groupby_col)
            reshaped_weighted_sum = weighted_data.group_by(groupby_col, maintain_order=True).sum(
            ).rename(lambda column_name: column_name + "_market_weighted_sum_gpby_" + groupby_col)
            output = pl.concat([reshaped_mean, reshaped_weighted_mean, reshaped_weighted_sum],
                               how='horizontal').with_columns(pl.col(groupby_col).cast(data[groupby_col].dtype))
            return data.select(SpecialCols.id_cols + [groupby_col]).join(output, how='left', on=groupby_col)
        else:
            # time_id is carried through the mean as the join key; this relies
            # on the batch holding a single time_id.
            reshaped_mean = data[feat_cols + ["time_id"]].mean().rename(lambda column_name: column_name +
                                                                        "_market_simple_mean" if column_name != "time_id" else column_name)
            reshaped_weighted_mean = weighted_data.mean().rename(
                lambda column_name: column_name + "_market_weighted_mean")
            reshaped_weighted_sum = weighted_data.sum().rename(
                lambda column_name: column_name + "_market_weighted_sum")
            output = pl.concat([reshaped_mean, reshaped_weighted_mean, reshaped_weighted_sum],
                               how='horizontal').with_columns(pl.col("time_id").cast(data["time_id"].dtype))
            return data.select(SpecialCols.id_cols).join(output, how='left', on="time_id")
    #

    # @timing_decorator_with_params("hf")
    def ts_mean_std_same_symbol(self, feats: List[str] = None) -> pl.DataFrame:
        """Per-symbol mean and standard deviation over the latest 30 time_ids of the current day.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: while fewer than 30 time_ids of the current day are
            available (or there is no data for it), an empty frame that still
            carries the ``<col>_avg_same_symbol`` / ``<col>_std_same_symbol``
            columns. Otherwise one row per symbol found in the window, stamped
            with the current ``date_id`` / ``time_id``.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        # Drop responders and id columns from the feature list.
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        # Intraday feature: only the current day is relevant.
        data = data.filter(pl.col(SpecialCols.date_id) == self.date_id)

        max_time_id = data[SpecialCols.time_id].max()
        # max() is None when the current day has no rows at all.
        if max_time_id is None or max_time_id < 29:
            # Not enough history yet: return zero rows with the full schema.
            return data.head(0).with_columns([
                pl.col(col).alias(f"{col}_avg_same_symbol")
                for col in feat_cols
            ]).with_columns([
                pl.col(col).alias(f"{col}_std_same_symbol")
                for col in feat_cols
            ])

        data = data.filter(pl.col(SpecialCols.time_id) >= max_time_id - 29)
        # NOTE(review): the result has one row per symbol seen anywhere in
        # the window, which may include symbols missing from the current
        # time_id — confirm callers align on symbol_id.
        return data.group_by([SpecialCols.symbol_id]).agg(
            [pl.col(col).mean().alias(f"{col}_avg_same_symbol") for col in feat_cols]
            + [pl.col(col).std().alias(f"{col}_std_same_symbol") for col in feat_cols]
        ).with_columns(
            [
                pl.lit(self.date_id).cast(pl.Int16).alias(SpecialCols.date_id),
                pl.lit(self.time_id).cast(pl.Int16).alias(SpecialCols.time_id),
            ]
        )

    #
    # @timing_decorator_with_params("hf")
    def ts_mean_std_same_timeid(self, feats: List[str] = None) -> pl.DataFrame:
        """Rolling mean / std of each feature at the same time_id across days.

        The per-(symbol, time_id) ``TSMEAN_STD`` trackers are first updated
        with the current batch via ``_update_tsmean_std``; then one result row
        is built for every symbol in ``self.new_symbols``.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: id columns plus ``<col>_avg_same_timeid`` and
            ``<col>_std_same_timeid`` for every feature.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        # Drop responders and id columns from the feature list.
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        self._update_tsmean_std(self.new_symbols, feats=feat_cols)

        result = []
        for symbol_id in self.new_symbols:
            obj = self.tsmean_stder[f"{symbol_id}_{self.time_id}"]
            ids = pl.DataFrame(
                [[self.date_id], [self.time_id], [symbol_id]],
                schema={
                    SpecialCols.date_id: pl.Int16,
                    SpecialCols.time_id: pl.Int16,
                    SpecialCols.symbol_id: pl.Int8,
                },
            )
            mean, std = obj.get_cur_result()
            mean = pl.DataFrame(mean.reshape(1, -1), schema=feat_cols).rename(
                {col: col + "_avg_same_timeid" for col in feat_cols}
            )
            std = pl.DataFrame(std.reshape(1, -1), schema=feat_cols).rename(
                {col: col + "_std_same_timeid" for col in feat_cols}
            )

            # One row per symbol: ids followed by its statistics.
            result.append(pl.concat([ids, mean, std], how="horizontal"))

        return pl.concat(result, how="vertical")

    # v
    # @timing_decorator_with_params("hf")
    def ts_simple_lag(self, feats: List[str] = None) -> pl.DataFrame:
        """Previous intraday value of each feature for the same symbol.

        Only rows of the current day with ``time_id >= self.time_id - 2`` are
        considered, and each feature is shifted by one row within
        ``(date_id, symbol_id)``. The shift follows the row order of
        ``self._data``.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: rows of the current time_id with an extra
            ``<col>_simple_lag_1`` column per feature.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]
        # Intraday lag: a few trailing time_ids of the current day suffice.
        data = data.filter(
            (pl.col(SpecialCols.date_id) == self.date_id)
            & (pl.col(SpecialCols.time_id) >= self.time_id - 2)
        )

        lags = data.with_columns(
            [
                pl.col(col)
                .shift(1)
                .over([SpecialCols.date_id, SpecialCols.symbol_id])
                .alias(col + "_simple_lag_1")
                for col in feat_cols
            ]
        )

        return lags.filter(pl.col(SpecialCols.time_id) == self.time_id)

    # v
    # @timing_decorator_with_params("hf")
    def ts_daily_lag(self, feats: List[str] = None, lag_n_day=1) -> pl.DataFrame:
        """Per-symbol mean / std / last value of each feature over a previous day.

        The statistics are computed once and cached in
        ``self.ts_daily_lag_res``; every call returns the cached rows whose
        ``time_id`` equals the current one.

        Source of the previous day's rows:

        * at ``time_id == 0``, when
          ``/kaggle/working/data_of_<date_id - lag_n_day>.parquet`` exists,
          the file is read and the cache is rebuilt;
        * otherwise the cache is built from ``self._data`` if it is still
          empty.

        Both sources use the same aggregations, so a column such as
        ``feature_01_std_lag_1_d`` has one definition (the sample standard
        deviation) regardless of where the rows came from, and the rows are
        re-stamped with the current ``date_id`` in both cases.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of the source.
        :param lag_n_day: how many days back to look; also part of the output
            column names (``<col>_{mean,std,last}_lag_<n>_d``).
        """
        symbol_key = [SpecialCols.symbol_id]
        excluded = SpecialCols.target_cols + SpecialCols.id_cols

        def _with_lag_stats(prev_day: pl.DataFrame, feat_cols: List[str]) -> pl.DataFrame:
            """Broadcast per-symbol statistics onto every row of ``prev_day``."""
            stat_exprs = []
            for stat in ("mean", "std", "last"):
                stat_exprs.extend(
                    getattr(pl.col(col), stat)()
                    .over(symbol_key)
                    .alias(f"{col}_{stat}_lag_{lag_n_day}_d")
                    for col in feat_cols
                )
            # Re-stamp the rows so they carry the current day's date_id.
            return prev_day.with_columns(
                pl.lit(self.date_id).alias(SpecialCols.date_id).cast(pl.Int16)
            ).with_columns(stat_exprs)

        last_data_file = f"/kaggle/working/data_of_{self.date_id - lag_n_day}.parquet"
        rebuild_from_file = self.time_id == 0 and os.path.exists(last_data_file)

        if rebuild_from_file:
            if feats is not None:
                data = pl.read_parquet(
                    last_data_file, columns=SpecialCols.id_cols + feats)
            else:
                data = pl.read_parquet(last_data_file)
            feat_cols = [col for col in data.columns if col not in excluded]

            # Window expressions keep every statistic attached to its own
            # symbol; combining separately grouped frames by position is not
            # safe because group order is not guaranteed.
            float_stats = [
                f"{col}_{stat}_lag_{lag_n_day}_d"
                for col in feat_cols
                for stat in ("mean", "std")
            ]
            self.ts_daily_lag_res = _with_lag_stats(data, feat_cols).with_columns(
                [pl.col(name).cast(pl.Float32) for name in float_stats]
            )
        elif self.ts_daily_lag_res is None:
            data = (
                self._data[SpecialCols.id_cols + feats]
                if feats is not None
                # polars frames expose clone(); there is no pandas-style copy().
                else self._data.clone()
            )
            feat_cols = [col for col in data.columns if col not in excluded]
            prev_day = data.filter(
                pl.col(SpecialCols.date_id) == self.date_id - lag_n_day
            )
            self.ts_daily_lag_res = _with_lag_stats(prev_day, feat_cols)

        # Only the rows of the current time_id are returned.
        return self.ts_daily_lag_res.filter(
            pl.col(SpecialCols.time_id) == self.time_id)
    # v

    # @timing_decorator_with_params("hf")
    def ts_macross(self, feats: List[str] = None) -> pl.DataFrame:
        """Relative distance between the 5- and 15-step intraday moving averages.

        Using the rows of the current day with
        ``time_id >= self.time_id - 20``, each feature gets, per symbol and
        ordered by time_id:

        * MA5 and MA15: rolling means over 5 and 15 rows, current row included;
        * ``<col>_macross = (MA5 - MA15) / MA15`` (null until 15 rows exist).

        No NaN-ratio filtering is applied.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: rows of the current time_id. Each original feature column
            holds its MA5 value and is accompanied by ``<col>_macross``.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )

        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        # Intraday moving averages: the trailing rows of the current day suffice.
        data = data.filter(
            (pl.col(SpecialCols.date_id) == self.date_id)
            & (pl.col(SpecialCols.time_id) >= self.time_id - 20)
        )

        def _moving_average(col: str, window_size: int) -> pl.Expr:
            return (
                pl.col(col)
                .rolling_mean(window_size=window_size)
                .over([SpecialCols.symbol_id], order_by=SpecialCols.time_id)
            )

        exprs = []
        for col in feat_cols:
            ma5 = _moving_average(col, 5)
            ma15 = _moving_average(col, 15)
            # The raw column is replaced by its MA5, as callers received before.
            exprs.append(ma5.alias(col))
            exprs.append(((ma5 - ma15) / ma15).alias(col + "_macross"))

        # All expressions are evaluated against the unmodified input frame.
        data_macross = data.with_columns(exprs)

        return data_macross.filter(pl.col(SpecialCols.time_id) == self.time_id)

    # @timing_decorator_with_params("hf")
    def ts_percentile(self, feats: List[str] = None) -> pl.DataFrame:
        """Intraday rolling rank of each feature's latest value.

        For every ``(date_id, symbol_id)`` group a window of
        ``self._ts_window`` time_id units is rolled over the current day. The
        rank of the newest value inside the window is divided by
        ``self._ts_window`` and shifted by -0.5.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: rows of the current time_id with ``<col>_percentile``
            columns. While the day has fewer than ``self._ts_window``
            distinct time_ids these columns are all null.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        # Intraday ranking: only the current day is relevant.
        data = data.filter(pl.col(SpecialCols.date_id) == self.date_id)

        # Not enough history yet: emit null percentiles.
        if data[SpecialCols.time_id].n_unique() < self._ts_window:
            return data.filter(
                pl.col(SpecialCols.time_id) == self.time_id
            ).with_columns(
                pl.lit(None).cast(pl.Float32).alias(col + "_percentile")
                for col in feat_cols
            )

        pct_rank = (
            # time_id is widened for the rolling index and narrowed again below
            data.with_columns(pl.col("time_id").cast(pl.Int32))
            .rolling(
                index_column="time_id",
                group_by=[SpecialCols.date_id, SpecialCols.symbol_id],
                period=f"{self._ts_window}i",
            )
            .agg(
                [
                    (pl.col(col).rank().last() / self._ts_window)
                    .sub(0.5)
                    .alias(col + "_percentile")
                    for col in feat_cols
                ]
            )
        )
        pct_rank = pct_rank.with_columns(pl.col("time_id").cast(pl.Int16))

        return pct_rank.filter(pl.col(SpecialCols.time_id) == self.time_id)

    # @timing_decorator_with_params("hf")
    def ts_diff_percentile(self, feats: List[str] = None) -> pl.DataFrame:
        """Intraday rolling rank of each feature's latest first difference.

        Same construction as ``ts_percentile``, applied to the step-to-step
        difference of every feature within ``(date_id, symbol_id)``, ordered
        by time_id.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: rows of the current time_id with ``<col>_diff_percentile``
            columns. While the day has fewer than ``self._ts_window``
            distinct time_ids these columns are all null and the frame also
            carries the raw and ``<col>_diff`` columns.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        # Intraday ranking: only the current day is relevant.
        data = data.filter(pl.col(SpecialCols.date_id) == self.date_id)

        # First difference per symbol, in time order.
        data = data.with_columns(
            [
                pl.col(col)
                .diff()
                .over([SpecialCols.date_id, SpecialCols.symbol_id], order_by=SpecialCols.time_id)
                .alias(col + "_diff")
                for col in feat_cols
            ]
        )

        # Not enough history yet: emit null percentiles.
        if data[SpecialCols.time_id].n_unique() < self._ts_window:
            return data.filter(
                pl.col(SpecialCols.time_id) == self.time_id
            ).with_columns(
                pl.lit(None).cast(pl.Float32).alias(col + "_diff_percentile")
                for col in feat_cols
            )

        diff_pct_rank = (
            # time_id is widened for the rolling index and narrowed again below
            data.with_columns(pl.col("time_id").cast(pl.Int32))
            .rolling(
                index_column="time_id",
                group_by=[SpecialCols.date_id, SpecialCols.symbol_id],
                period=f"{self._ts_window}i",
            )
            .agg(
                [
                    (pl.col(col + "_diff").rank().last() / self._ts_window)
                    .sub(0.5)
                    .alias(col + "_diff" + "_percentile")
                    for col in feat_cols
                ]
            )
        )
        diff_pct_rank = diff_pct_rank.with_columns(
            pl.col("time_id").cast(pl.Int16))

        return diff_pct_rank.filter(pl.col(SpecialCols.time_id) == self.time_id)

    # @timing_decorator_with_params("hf")
    def ts_zscore(self, feats: List[str] = None) -> pl.DataFrame:
        """Intraday z-score of each feature against its recent per-symbol history.

        Rows of the current day with
        ``time_id >= self.time_id - self._ts_window`` form the window. Per
        symbol, the mean and standard deviation over that window are used to
        standardise every value. No NaN-ratio filtering is applied.

        :param feats: feature columns to use; ``None`` uses every non-id,
            non-responder column of ``self._data``.
        :return: rows of the current time_id with ``<col>_zscore`` columns.
            An empty frame with the same columns is returned while the latest
            time_id is below ``int(self._ts_window * 0.25) - 1`` or when the
            current day has no rows.
        """
        data = (
            self._data[SpecialCols.id_cols + feats]
            if feats is not None
            # polars frames expose clone(); there is no pandas-style copy().
            else self._data.clone()
        )
        feat_cols = [
            feat
            for feat in data.columns
            if feat not in SpecialCols.target_cols + SpecialCols.id_cols
        ]

        # Intraday statistic: only the trailing window of the current day.
        data = data.filter(
            (pl.col(SpecialCols.date_id) == self.date_id)
            & (pl.col(SpecialCols.time_id) >= self.time_id - self._ts_window)
        )
        max_time_id = data[SpecialCols.time_id].max()

        # max() is None when the filter left no rows.
        if max_time_id is None or max_time_id < int(self._ts_window * 0.25) - 1:
            # Too little history: zero rows, full schema.
            return data.head(0).with_columns(
                [pl.col(col).alias(col + "_zscore") for col in feat_cols]
            )

        def _per_symbol(expr: pl.Expr) -> pl.Expr:
            return expr.over([SpecialCols.symbol_id], order_by=SpecialCols.time_id)

        data_zscore = data.with_columns(
            [
                (
                    (pl.col(col) - _per_symbol(pl.col(col).mean()))
                    / _per_symbol(pl.col(col).std())
                ).alias(col + "_zscore")
                for col in feat_cols
            ]
        )
        return data_zscore.filter(pl.col(SpecialCols.time_id) == self.time_id)

            # return data_zscore.filter((pl.col(SpecialCols.time_id) == self.time_id))
            # # @timing_decorator_with_params("hf")

    def _update_tsmean_std(self, symbols, feats: List[str] = None) -> None:
        """Feed the current batch into the per-(symbol, time_id) trackers.

        A ``TSMEAN_STD`` tracker (window 30) is created for every id in
        ``symbols`` that has none yet for the current time_id. Then each row
        of ``self._data`` at the current ``date_id`` / ``time_id`` is loaded
        into the tracker of its symbol.

        :param symbols: symbol ids for which trackers must exist.
        :param feats: feature columns passed to the trackers, in this order.
        :raises KeyError: if the batch contains a symbol that has no tracker
            for the current time_id.
        """
        for symbol_id in symbols:
            key = f"{symbol_id}_{self.time_id}"
            if key not in self.tsmean_stder:
                self.tsmean_stder[key] = TSMEAN_STD(
                    instrument_id=symbol_id,
                    time_id=self.time_id,
                    window=30,
                    features_n=len(feats),
                )

        # Rows of the current (date_id, time_id) cross-section.
        data = self._data.filter(
            (pl.col(SpecialCols.date_id) == self.date_id)
            & (pl.col(SpecialCols.time_id) == self.time_id)
        )

        # Nothing to load.
        if data.shape[0] == 0:
            return

        # Extract ids and values once instead of re-filtering the frame for
        # every row.
        symbol_ids = data[SpecialCols.symbol_id].to_list()
        values = data[feats].to_numpy()
        for symbol_id, row in zip(symbol_ids, values):
            self.tsmean_stder[f"{symbol_id}_{self.time_id}"].load_new_value(
                row.reshape((1, -1))
            )

    # @ timing_decorator_with_params("hf")
    def gen_feats(self):
        """Compute every configured feature for the current batch.

        Each entry of ``HFDataProcessor.function_to_use`` names a feature
        method (optionally ``<method>_gpby_<column>``), its input features
        (``func_origin_feats``) and the output columns to keep
        (``func_selected_feats``). ``ts_mean_std_same_timeid`` and
        ``ts_regress`` are skipped.

        Results are attached to the ids of ``self.cur_processed_raw_data`` by
        a left join on ``symbol_id``. Feature methods may return a different
        set of symbols than the current batch (for example rows taken from a
        previous day), so aligning by position could pair a value with the
        wrong symbol; symbols without a value get null.

        :return: one row per row of the current batch, sorted by the id
            columns: the id columns followed by the selected feature columns
            in configuration order.
        """
        symbol_col = SpecialCols.symbol_id
        all_features_data = self.cur_processed_raw_data.select(
            SpecialCols.id_cols
        ).sort(SpecialCols.id_cols)
        symbol_dtype = all_features_data.schema[symbol_col]

        for func, params in HFDataProcessor.function_to_use.items():
            if func in ["ts_mean_std_same_timeid", "ts_regress"]:
                continue

            # sorted() keeps the input order stable from run to run.
            kwargs = {"feats": sorted(set(params["func_origin_feats"]))}
            func_name = func
            if "_gpby_" in func:
                parts = func.split("_gpby_")
                func_name, kwargs["groupby_col"] = parts[0], parts[1]

            selected = list(params["func_selected_feats"])
            feature_frame = getattr(self, func_name)(**kwargs).select(
                # Match the key dtype so the join cannot fail on a mismatch.
                [pl.col(symbol_col).cast(symbol_dtype)] + selected
            )
            all_features_data = all_features_data.join(
                feature_frame, on=symbol_col, how="left"
            )

        return all_features_data

    # @ timing_decorator_with_params("hf")
    def update_original_data(
        self, raw_data: pl.DataFrame, lag1d_responder: pl.DataFrame
    ):
        """Append the newest raw rows to the rolling buffer ``self._data``.

        Args:
            raw_data: raw rows including responder columns (they may be filled
                with 0); must have exactly 94 columns.
            lag1d_responder: responder values lagged by one day, passed through
                to ``HFDataProcessor.easy_process_raw_data``.

        Side effects:
            Sets ``self.cur_processed_raw_data`` to the processed new rows and
            updates ``self._data`` so that it holds at most the last
            ``self._keep_length`` rows.

        Raises:
            AssertionError: if ``raw_data`` does not have 94 columns or the
                processed frame's width differs from the existing buffer.
        """
        assert raw_data.shape[1] == 94, "init_data must be original shape"
        self.cur_processed_raw_data = HFDataProcessor.easy_process_raw_data(
            raw_data, lag1d_responder
        )

        # First batch: the buffer is just the (truncated) new rows.
        if self._data is None:
            self._data = self.cur_processed_raw_data.tail(self._keep_length)
            return

        assert (
            self.cur_processed_raw_data.shape[1] == self._data.shape[1]
        ), "must be the same shape"

        # Concatenate first and truncate afterwards. Slicing the old buffer with
        # `[-(keep_length - input_len):]` breaks when input_len >= keep_length:
        # the start becomes `-0` (keeps everything) or a positive index, so the
        # buffer could grow past keep_length.
        self._data = pl.concat(
            [self._data, self.cur_processed_raw_data], how="vertical"
        ).tail(self._keep_length)

        return

    # @ timing_decorator_with_params("hf")
    def update_processed_data(self, processed_data: pl.DataFrame):
        """Push the newest generated-feature frame into the bounded cache.

        ``self._cache`` holds at most ``self._seq_len`` frames; when it is
        already full the oldest entry is evicted before appending.
        """
        history = self._cache
        is_full = len(history) >= self._seq_len
        if is_full:
            history.pop(0)  # drop the oldest frame
        history.append(processed_data)
        return


    def init_data(self, init_data):
        """Seed ``self._data`` from a historical frame.

        ``date_id`` is shifted so that the latest day in ``init_data`` becomes
        ``-1``. For every responder column a ``<col>_lag1d`` column is added
        (previous date's value within the same symbol_id/time_id group), plus a
        constant ``is_scored=False`` and ``row_id=0`` (UInt32). Columns are then
        arranged according to ``self.col_orders``.
        """
        # Re-base date_id: the most recent historical day maps to -1.
        latest_day = init_data[SpecialCols.date_id].max()
        rebased = init_data.with_columns(
            pl.col(SpecialCols.date_id) - latest_day - 1
        )

        # One-day lag of each responder, per (symbol_id, time_id), ordered by date.
        lag_exprs = [
            pl.col(col)
            .shift(1)
            .over(
                [SpecialCols.symbol_id, SpecialCols.time_id],
                order_by=SpecialCols.date_id,
            )
            .alias(col + "_lag1d")
            for col in SpecialCols.target_cols
        ]
        with_lags = rebased.with_columns(lag_exprs)

        # Placeholder bookkeeping columns.
        with_flag = with_lags.with_columns(
            [pl.lit(False).cast(pl.Boolean).alias("is_scored")]
        )
        with_row_id = with_flag.with_columns(
            (pl.lit(0)).alias("row_id").cast(pl.UInt32)
        )

        self._data = with_row_id[self.col_orders]

    # @ timing_decorator_with_params("hf")
    def update(
        self, date_id, time_id, raw_data: pl.DataFrame, lag1d_responder: pl.DataFrame
    ):
        """Public entry point: ingest one time step of data; called every tick.

        ``raw_data`` is expected to contain a single ``time_id``.

        When ``time_id == 0`` the ``_lag_1`` suffix is stripped from the columns
        of ``lag1d_responder`` (in place on the passed frame) and the frame is
        remembered in ``self._temp_lag``. ``self._temp_lag`` is only assigned in
        that branch, so the first call must be made with ``time_id == 0``.

        Afterwards the raw buffer is updated, PCA features are computed on the
        null-filled processed rows, generated features are left-joined onto the
        PCA frame by the id columns, and the result is cached unfilled.
        """
        self.date_id, self.time_id = date_id, time_id

        if time_id == 0:
            # Lag handling: responder_k_lag_1 -> responder_k.
            stripped_names = [
                name.replace('_lag_1', '') for name in lag1d_responder.columns
            ]
            lag1d_responder.columns = stripped_names
            self._temp_lag = lag1d_responder

        # Bookkeeping used by the inter-day feature computation.
        self.new_symbols = set(raw_data[SpecialCols.symbol_id].unique())

        lag_for_tick = self._temp_lag.filter(
            (pl.col(SpecialCols.time_id) == self.time_id)
        )
        self.update_original_data(raw_data, lag_for_tick)

        # PCA features for the rows that were just added.
        self.pca_data = self.add_pca(self.cur_processed_raw_data.fill_null(0))

        generated = self.gen_feats()

        # Attach the generated features to the PCA frame.
        merged = self.pca_data.join(generated, on=SpecialCols.id_cols, how='left')

        # The cached frame is stored without fill_null or any other transform.
        self.update_processed_data(merged)

        # Bookkeeping used by the inter-day feature computation.
        self.symbols = self.new_symbols

    # @ timing_decorator_with_params("hf")
    def get(self) -> pl.DataFrame:
        """Public accessor: all cached frames stacked vertically, nulls replaced by 0."""
        stacked = pl.concat(self._cache, how='vertical')
        return stacked.fill_null(0)

In [7]:
# init_data = pl.read_parquet(os.path.join(CONFIG.path, "train.parquet", "partition_id=9", "part-0.parquet"))
# dp = HFDataProcessor()
# dp.init_data(init_data)
# del init_data
# gc.collect()

# last_date, date_cache = None, []
# test_path = "/kaggle/input/aft-js-dataset/synthetic_test.parquet_short"
# lags_path = "/kaggle/input/aft-js-dataset/synthetic_lag.parquet_short"
# for test_data, _ in tqdm(generate_data_batches(test_path, lags_path)):
#     test, lags = test_data
#     test = test.with_columns([pl.lit(0).alias(target_col).cast(pl.Float32) for target_col in SpecialCols.target_cols])  # 必须传这个
#     date_id, time_id = test['date_id'][0], test['time_id'][0]
#     if last_date is None: 
#         last_date = date_id
#     if last_date != date_id and lags is not None:  # 新的一天更新
#         _lags = lags.with_columns((pl.col('date_id') - 1).alias('date_id'))
#         _lags = _lags.rename({col: col.replace('_lag_1', '') for col in _lags.columns if '_lag_1' in col})
#         last_date_all_data = pl.concat(date_cache) # 不含target
#         last_date_all_data = last_date_all_data.join(_lags, on=SpecialCols.id_cols, how='left')
#         last_date_all_data.write_parquet(f"/kaggle/working/data_of_{last_date}.parquet")
#         date_cache = []
#     dp.update(date_id, time_id, test, lags)
#     one_time_data = dp.get()  # get 出来不含target
#     # 保存数据
#     newly_updated = one_time_data.filter(pl.col("time_id") == time_id).filter(pl.col("date_id") == date_id)  # 不含target
#     date_cache.append(newly_updated) # 不含target
#     last_date = date_id
#     ####
#     #### dataloader
#     one_time_data = one_time_data.with_columns([pl.lit(0).alias(target_col).cast(pl.Float32) for target_col in SpecialCols.target_cols]).drop(['row_id', 'is_scored'])
#     forecast_loader = get_dataloader(one_time_data, step_len=16, get_last_batch=True)
#     for combined_batch in forecast_loader:
#         combined_batch = combined_batch.float()
#         X = combined_batch[:,:,:,:]  # 最后一个Batch的取值
#         break

In [8]:
# data = dp.get()
# data.head()

In [9]:
# data = pl.read_parquet("/kaggle/input/aft-js-dataset/data_of_0.parquet")

In [10]:
# import pickle
# with open("/kaggle/working/data_columns.pkl", 'wb') as f:
#     pickle.dump(data.columns,f)

**model definition**

trainer

In [11]:
class EarlyStopper:
    """Track the lowest monitored value and flag when it stops improving.

    A value counts as an improvement when it is lower than the best seen so far
    by more than ``delta``. After ``patience`` consecutive non-improving calls
    ``is_earlystop`` becomes True. ``best_model`` holds a deep copy of the
    model's ``state_dict()`` taken at the best value.
    """

    def __init__(
        self,
        patience=5,
        delta=0,
    ):
        self.patience = patience
        self.delta = delta
        self.count = 0
        self.best_value = None
        self.best_model = None
        self.is_earlystop = False

    def _remember(self, value, model):
        """Store ``value`` as the new best together with a snapshot of the weights."""
        self.best_value = value
        self.best_model = deepcopy(model.state_dict())

    def earlystop(self, value: float, model):
        """Register one evaluation result (lower is better)."""
        # Very first observation: just record it.
        if self.best_value is None:
            self._remember(value, model)
            return

        # Sufficient improvement: record it and restart the patience counter.
        if (self.best_value - value) > self.delta:
            self._remember(value, model)
            self.count = 0
            return

        # No improvement this round.
        self.count += 1
        print("EarlyStop Counter: {:02d}".format(self.count))
        if self.count >= self.patience:
            self.is_earlystop = True
            print("# EarlyStop!")


class IncrementalTrainer:
    """Train / fine-tune a torch model on combined batches with early stopping.

    Every batch is indexed as ``batch[:, :, :, :input_dim]`` for the features
    and ``batch[:, :, -1, weight_pos]`` / ``batch[:, :, -1, target_pos]`` for
    the sample weight and the label taken at the last time step.
    # assumes batches are laid out as (batch, cross-section, time, columns) — TODO confirm

    Args:
        model: the ``nn.Module`` to train; it is moved to CUDA when available.
        batch_size: forwarded to ``batch_combine``.
        lr: Adam learning rate.
        patience: early-stopping patience (in epochs).
        epoch: maximum number of epochs per ``incremental_update`` call.
    """

    def __init__(self, model, batch_size=242, lr=0.001, patience=5, epoch=100):
        self.input_dim = 237
        self.patience = patience
        self.epoch = epoch
        self.model = model
        self.early_stop = EarlyStopper(patience=self.patience)
        self.device = (
            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model.to(self.device)
        self.target_pos = -3
        self.weight_pos = -10
        self.batch_size = batch_size
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = self.loss_fn
        # Set to True by load_model(); initialised here so the attribute always exists.
        self.fitted = False

    def loss_fn(self, pred, label, weight=None):
        """Mean (optionally weighted) squared error.

        Args:
            pred: predictions, any shape with B*N elements.
            label: labels with the same number of elements as ``pred``.
            weight: optional per-sample weights with the same number of
                elements; when ``None`` the plain MSE is returned.
        """
        pred = pred.view(-1, 1)
        label = label.view(-1, 1)
        loss = (pred - label) ** 2
        # Only touch `weight` after the None check; reshaping it first made the
        # default `weight=None` crash with AttributeError.
        if weight is not None:
            loss = weight.view(-1, 1) * loss
        return torch.mean(loss)

    def metric_fn(self, pred, label, weight):
        """Return ``(mean weighted squared error, weighted R2)``.

        Labels are clamped to [-5, 5] first.
        ``wr2 = 1 - sum(w * (y - y_hat)^2) / sum(w * y^2)``.
        """
        pred = pred.view(-1, 1)
        label = label.view(-1, 1)
        weight = weight.view(-1, 1)
        label = torch.clamp(label, -5, 5)
        wmse = weight * (pred - label) ** 2
        wr2 = 1 - torch.sum(wmse) / torch.sum(weight * label**2)
        return torch.mean(wmse), wr2

    def _split_batch(self, combined_batch):
        """Split one combined batch into (features, label, weight) on ``self.device``."""
        combined_batch = combined_batch.float()
        X = combined_batch[:, :, :, : self.input_dim]
        weight = combined_batch[:, :, -1, self.weight_pos]
        y = combined_batch[:, :, -1, self.target_pos]
        return X.to(self.device), y.to(self.device), weight.to(self.device)

    def _train_epoch(self, data_loader):
        """Run one optimisation pass over ``data_loader``."""
        self.model.train()
        for combined_batch in tqdm(
            batch_combine(data_loader, batch_size=self.batch_size, drop_last=True),
            desc="train epoch",
        ):
            X, label, weight = self._split_batch(combined_batch)
            preds = self.model(X)
            loss = self.criterion(preds, label, weight)
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping by value.
            nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
            self.optimizer.step()

    def _test_epoch(self, data_loader):
        """Evaluate on ``data_loader``.

        Returns:
            Tuple of batch-averaged (training loss, weighted MSE, weighted R2).
        """
        self.model.eval()
        loss_list = []
        valid_loss_list = []
        metric_list = []
        with torch.no_grad():
            for combined_batch in tqdm(
                batch_combine(data_loader, batch_size=self.batch_size, drop_last=True),
                desc="test epoch",
            ):
                X, label, weight = self._split_batch(combined_batch)
                preds = self.model(X)
                loss = self.criterion(preds, label, weight)
                wmse, wr2 = self.metric_fn(preds, label, weight)
                loss_list.append(loss.item())
                valid_loss_list.append(wmse.item())
                metric_list.append(wr2.item())
        return np.mean(loss_list), np.mean(valid_loss_list), np.mean(metric_list)

    def incremental_update(self, train_loader, valid_loader):
        """Train for up to ``self.epoch`` epochs and restore the best weights.

        Early stopping monitors the negated weighted R2 (i.e. maximises R2).
        The early stopper is reset at the end so the next call starts fresh.
        """
        # An initial validation pass before training was tried but did not help.
        for epoch in range(self.epoch):
            self._train_epoch(train_loader)
            mean_loss = self._test_epoch(valid_loader)
            self.early_stop.earlystop(-mean_loss[2], self.model)  # maximise wr2
            if self.early_stop.is_earlystop:
                break
            print(f"epoch:{epoch}, loss:{mean_loss}")
        # With epoch == 0 no snapshot exists; keep the current weights then.
        if self.early_stop.best_model is not None:
            self.model.load_state_dict(self.early_stop.best_model)
        self.early_stop = EarlyStopper(patience=self.patience)

    def load_model(self, path):
        """Load a ``state_dict`` from ``path`` into the model and mark it fitted."""
        print("load model at %s" % path)
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.fitted = True

    def save_model(self, save_path):
        """Write the model's ``state_dict`` to ``save_path``."""
        torch.save(self.model.state_dict(), save_path)

**model training**

gru

In [12]:
class RNNModel(nn.Module):
    """Recurrent encoder (RNN / GRU / LSTM) followed by a linear head.

    Args:
        d_feat: number of input features per time step.
        hidden_size: recurrent hidden width.
        num_layers: number of stacked recurrent layers.
        dropout: dropout between recurrent layers.
        layer_type: "rnn", "gru" or "lstm" (case-insensitive).
        output_size: width of the linear output.
    """

    def __init__(
        self,
        d_feat,
        hidden_size,
        num_layers,
        dropout,
        layer_type: str,
        output_size: int,
    ):
        super().__init__()
        kind = layer_type.upper()
        assert kind in ["RNN", "GRU", "LSTM"], (
            "Unexpected layer name: %s" % layer_type
        )
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Attribute names `rnn` / `linear` are part of the saved state_dict keys.
        recurrent_cls = getattr(nn, kind)
        self.rnn = recurrent_cls(
            input_size=d_feat,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, X):
        """Map a 4-D input ``(B, N, S, F)`` to ``(B*N, output_size)``.

        The first two axes are merged, the sequence is run through the
        recurrent stack, and the last time step's hidden state is projected.
        """
        _, _, seq_len, n_feat = X.shape
        sequences = X.view(-1, seq_len, n_feat)
        hidden_states, _ = self.rnn(sequences)
        last_step = hidden_states[:, -1, :]
        return self.linear(last_step)

# GRU model: d_feat=237, hidden_size=128, num_layers=2, dropout=0.2, output_size=1.
# NOTE: the module-level name `model` is reassigned by the transformer cell below.
model = RNNModel(237, 128, 2, 0.2, "gru", 1)
# epoch=1: each incremental_update() call trains for at most one epoch.
gru_trainer = IncrementalTrainer(model, epoch=1)
# Restore pretrained weights from the attached Kaggle dataset.
gru_trainer.load_model("/kaggle/input/aft-js-dataset/gru_model_state.pth")

load model at /kaggle/input/aft-js-dataset/gru_model_state.pth


transformer

In [13]:
import torch.nn.functional as F
import copy, math


def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Each copy is fully detached from the original (no shared storage), so every
    copy has its own trainable parameters and gradients.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas


class attention(nn.Module):
    """Scaled dot-product attention (no mask, no dropout)."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        """Attend ``value`` with softmax(query @ key^T / sqrt(d_k)).

        Only the last two axes take part in the computation; leading axes
        (batch, heads) are carried through unchanged.

        Returns:
            ``(output, weights)`` where ``weights`` is the softmax over the
            last axis of the similarity matrix; it is returned so callers can
            inspect or visualise the attention pattern.
        """
        # Dividing by sqrt(d_k) keeps the dot products from growing with width.
        scale = math.sqrt(query.size(-1))
        similarity = torch.matmul(query, key.transpose(-2, -1)) / scale
        weights = F.softmax(similarity, dim=-1)
        return torch.matmul(weights, value), weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built from bias-free Q/K/V/output projections.

    Args:
        num_head: number of attention heads.
        d_model: model width; must be divisible by ``num_head``
            (``d_k == d_v == d_model // num_head``).
    """

    def __init__(self, num_head, d_model):
        super().__init__()
        self.d_model = d_model
        self.num_head = num_head
        assert d_model % num_head == 0  # We assume d_v always equals d_k
        per_head = self.d_model // self.num_head
        self.d_k = per_head
        self.d_v = per_head

        # Input projections, then the projection that merges the heads again.
        self.W_Q = nn.Linear(self.d_model, self.d_k * self.num_head, bias=False)
        self.W_K = nn.Linear(self.d_model, self.d_k * self.num_head, bias=False)
        self.W_V = nn.Linear(self.d_model, self.d_v * self.num_head, bias=False)
        self.fc = nn.Linear(self.num_head * self.d_v, self.d_model, bias=False)

        self.attention_ = attention()
        # Attention weights of the most recent forward pass.
        self.attn = None

    def forward(self, query, key, value):
        """Attend over ``[batch_size, sequence_len, d_model]`` inputs.

        Returns a tensor of shape ``[batch_size, sequence_len, d_model]`` and
        stores the attention weights in ``self.attn``.
        """
        nbatches = query.size(0)

        def to_heads(projected, depth):
            # [batch, seq, heads*depth] -> [batch, heads, seq, depth]
            return projected.view(nbatches, -1, self.num_head, depth).transpose(1, 2)

        q_heads = to_heads(self.W_Q(query), self.d_k)
        k_heads = to_heads(self.W_K(key), self.d_k)
        v_heads = to_heads(self.W_V(value), self.d_v)

        context, self.attn = self.attention_(q_heads, k_heads, v_heads)

        # [batch, heads, seq, d_k] -> [batch, seq, heads*d_k]
        merged = (
            context.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.num_head * self.d_k)
        )
        return self.fc(merged)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: linear -> ReLU -> linear (both with bias).

    Maps ``[..., d_model]`` to ``[..., d_ff]`` and back to ``[..., d_model]``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # expand to d_ff
        self.w_2 = nn.Linear(d_ff, d_model)  # project back to d_model

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        return self.w_2(hidden)


class SublayerConnection(nn.Module):
    """Residual wrapper around a sublayer with layer normalisation.

    The same ``LayerNorm`` instance is applied twice: once to the input before
    the sublayer and once to the residual sum that is returned.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, sublayer):
        """Return ``norm(inputs + dropout(sublayer(norm(inputs))))``.

        ``sublayer`` is any callable mapping a tensor to a same-shaped tensor
        (e.g. an attention or feed-forward module).
        """
        transformed = sublayer(self.norm(inputs))
        residual = inputs + self.dropout(transformed)
        return self.norm(residual)


class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each of the two stages is wrapped in its own ``SublayerConnection``
    (two deep copies created via ``clones``).
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def _self_attend(self, hidden):
        """Self-attention: query, key and value are all the same tensor."""
        return self.self_attn(hidden, hidden, hidden)

    def forward(self, x):
        attn_wrapper, ff_wrapper = self.sublayer
        attended = attn_wrapper(x, self._self_attend)
        return ff_wrapper(attended, self.feed_forward)


class Encoder(nn.Module):
    """Apply ``layer`` ``num_layers`` times, then a final ``LayerNorm``.

    Args:
        layer: module with a ``size`` attribute (e.g. an ``EncoderLayer``).
        num_layers: how many times the layer is applied.

    Note:
        ``self.layers`` holds the *same* ``layer`` object ``num_layers`` times,
        so all positions in the stack share one set of weights.
    """

    def __init__(self, layer, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([layer] * num_layers)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x):
        # x: [batch_size, sequence_len, d_model] in, same shape out.
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


class Transformer(nn.Module):
    """Encoder-only transformer applied across the second axis of the input.

    Only the last time step of the input is used; attention then runs over the
    second axis (the cross-section) of each batch element.

    Args:
        num_layers: number of encoder layer applications.
        d_feat: input feature width.
        d_model: embedding / model width.
        d_ff: feed-forward hidden width.
        num_heads: attention heads (must divide ``d_model``).
        output_size: width of the output projection.
        dropout: dropout rate used in the sublayer connections.
    """

    def __init__(
        self,
        num_layers,
        d_feat,
        d_model,
        d_ff,
        num_heads,
        output_size,
        dropout,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.d_feat = d_feat
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.output_size = output_size
        self.dropout = dropout

        # Attribute names below are part of the saved state_dict keys.
        self.embeding = nn.Linear(d_feat, d_model)
        self.attn = MultiHeadedAttention(num_heads, d_model)
        self.ff = PositionwiseFeedForward(d_model, d_ff)
        self.encoderlayer = EncoderLayer(d_model, self.attn, self.ff, dropout)
        self.encoder = Encoder(self.encoderlayer, num_layers)
        # Custom init runs before `outputs` exists, so the output head keeps
        # PyTorch's default initialisation.
        self.init_weights()
        self.outputs = nn.Linear(d_model, output_size)

    def init_weights(self):
        """Re-initialise every parameter with more than one dimension.

        Uses a truncated normal (mean 0, std 0.01, clipped to [-0.1, 0.1]).
        """
        matrices = [p for p in self.parameters() if p.dim() > 1]
        for weight in matrices:
            nn.init.trunc_normal_(weight, mean=0, std=0.01, a=-0.1, b=0.1)

    def forward(self, src):
        """Map ``src`` of shape ``(B, N, S, d_feat)`` to ``(B, N, output_size)``."""
        n_batch, n_cross = src.shape[0], src.shape[1]
        latest = src[:, :, -1, :]  # cross-section at the last time step
        encoded = self.encoder(self.embeding(latest))
        x_hidden = encoded.reshape(n_batch, n_cross, -1)
        return self.outputs(x_hidden)

# Transformer: num_layers=3, d_feat=237, d_model=128, d_ff=128, num_heads=16,
# output_size=1, dropout=0.5.
# NOTE: this rebinds the module-level `model`; the GRU stays reachable via gru_trainer.model.
model = Transformer(3, 237, 128, 128, 16, 1, 0.5)
# epoch=1: each incremental_update() call trains for at most one epoch.
transformer_trainer = IncrementalTrainer(model, epoch=1)
# Restore pretrained weights from the attached Kaggle dataset.
transformer_trainer.load_model("/kaggle/input/aft-js-dataset/transformer_model_state_3_16_05.pth")

load model at /kaggle/input/aft-js-dataset/transformer_model_state_3_16_05.pth


mlp

In [14]:
# class MYMLP(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super().__init__()
#         # Time-series MLP
#         self.time_series_mlp = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU()
#         )

#         # Cross-sectional MLP
#         self.cross_sectional_mlp = nn.Sequential(
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )

#     def forward(self, data):
#         """
#         [batch_size, cross_number, time_steps, feature_size] -> (batch_size x cross_number, output_dim)
#         """
#         batch_size, cross_number, time_steps, feature_size = data.shape

#         # Process time-series information
#         data = data.view(batch_size * cross_number, time_steps, feature_size)
#         time_features = self.time_series_mlp(data.view(-1, feature_size))  # [(batch_size * cross_number * time_steps) x hidden_dim]
#         time_features = time_features.view(batch_size * cross_number, time_steps, -1).mean(dim=1)  # Aggregate across time steps

#         # Reshape back to cross-sectional structure
#         time_features = time_features.view(batch_size, cross_number, -1)  # Shape: [batch_size, cross_number, hidden_dim]

#         # Process cross-sectional features
#         cross_features = self.cross_sectional_mlp(time_features)  # Shape: [batch_size, cross_number, output_dim]
#         return cross_features
# # Model configuration
# input_dim = input_dim
# hidden_dim = 32
# output_dim = 1
# model = MYMLP(input_dim, hidden_dim, output_dim)
# # trainer
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# trainer = IncrementalTrainer(model, optimizer, criterion)
# # training
# step_len = 16
# batch_size = 242
# start = time.time()
# exist_data = pl.read_parquet("/kaggle/input/aft-js-dataset/data_of_0.parquet").drop(['row_id', 'is_scored'])
# # # 计算每列的均值和标准差
# # all_feature = list(set(exist_data.columns)-set([SpecialCols.weight_col]+SpecialCols.target_cols+SpecialCols.id_cols))
# # features = exist_data[all_feature].to_numpy()
# # means, stds = np.mean(features, axis=0), np.std(features, axis=0)
# # stds = np.where(stds < 1e-6, 1e-6, stds)
# # exist_data = standardize_polars(exist_data, means, stds)
# ###
# train_loader = get_dataloader(exist_data, step_len)
# test_loader = get_dataloader(exist_data, step_len)
# trainer.incremental_update(train_loader, test_loader)
# end = time.time()
# print(end-start)

MASTER

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes along the time axis of a 4-D input.

    Args:
        d_model: feature width of the input (even or odd).
        max_len: number of positions precomputed; inputs must not have more
            time steps than this.

    The table is stored as the buffer ``pe`` with shape ``(max_len, d_model)``:
    even columns hold sines, odd columns hold cosines.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns;
        # slice the angles so the assignment does not fail on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x + pe[:T]`` for ``x`` of shape ``(batch, N, T, d_model)``."""
        return x + self.pe[:x.shape[2], :].unsqueeze(0).unsqueeze(0)


class SAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.temperature = math.sqrt(self.d_model/nhead)

        self.qtrans = nn.Linear(d_model, d_model, bias=False)
        self.ktrans = nn.Linear(d_model, d_model, bias=False)
        self.vtrans = nn.Linear(d_model, d_model, bias=False)

        attn_dropout_layer = []
        for i in range(nhead):
            attn_dropout_layer.append(Dropout(p=dropout))
        self.attn_dropout = nn.ModuleList(attn_dropout_layer)

        # input LayerNorm
        self.norm1 = LayerNorm(d_model, eps=1e-5)
        # self.norm1 = BatchNorm1d(d_model, eps=1e-5)


        # FFN layerNorm
        self.norm2 = LayerNorm(d_model, eps=1e-5)
        # self.norm2 = BatchNorm1d(d_model, eps=1e-5)

        self.ffn = nn.Sequential(
            Linear(d_model, d_model),
            nn.ReLU(),
            Dropout(p=dropout),
            Linear(d_model, d_model),
            Dropout(p=dropout)
        )

    def forward(self, x):
        batch, N, T, D = x.shape
        x = x.permute(0, 2, 1, 3).contiguous().view(-1, N, D) # [batch, N, T, D] --> [batch, T, N, D]
        #x = x.permute(0, 2, 3, 1).contiguous().view(-1, D, N) # [batch, N, T, D] --> [batch, T, D, N] --> [batch * T, D, N]

        x = self.norm1(x)
        q = self.qtrans(x)
        k = self.ktrans(x)
        v = self.vtrans(x)

        dim = int(self.d_model/self.nhead)
        att_output = []
        for i in range(self.nhead):
            if i==self.nhead-1:
                qh = q[:, :, i * dim:]
                kh = k[:, :, i * dim:]
                vh = v[:, :, i * dim:]
            else:
                qh = q[:, :, i * dim:(i + 1) * dim]
                kh = k[:, :, i * dim:(i + 1) * dim]
                vh = v[:, :, i * dim:(i + 1) * dim]

            atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)) / self.temperature, dim=-1) # [batch * T, N, D] * [batch * T, D, N] --> [batch * T, N, N]
            if self.attn_dropout:
                atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh) 
            att_output.append(torch.matmul(atten_ave_matrixh, vh)) # [batch *T, N, N] * [batch *T, N, D] --> [batch *T, N, D]
        att_output = torch.concat(att_output, dim=-1)

        # FFN
        xt = x + att_output
        xt = self.norm2(xt)
        att_output = xt + self.ffn(xt)

        return att_output.view(batch, T, N, D).permute(0, 2, 1, 3).contiguous() # [batch * T, N, D] --> [batch, T, N, D] --> [batch, N, T, D]


class TAttention(nn.Module):
    """Temporal multi-head attention: for each (batch, n) series, attend over axis T.

    Input and output shape: ``[batch, N, T, D]`` with ``D == d_model``.
    The attention scores are not scaled by a temperature.

    Args:
        d_model: feature width.
        nhead: number of heads; the last head absorbs any remainder when
            ``d_model`` is not divisible by ``nhead``.
        dropout: dropout probability; per-head attention dropout is only
            created when it is greater than 0.
    """

    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.qtrans = nn.Linear(d_model, d_model, bias=False)
        self.ktrans = nn.Linear(d_model, d_model, bias=False)
        self.vtrans = nn.Linear(d_model, d_model, bias=False)

        # Layers are referenced through `nn.` so the class does not rely on bare
        # `Dropout` / `LayerNorm` / `Linear` names being imported elsewhere.
        self.attn_dropout = []
        if dropout > 0:
            self.attn_dropout = nn.ModuleList(
                [nn.Dropout(p=dropout) for _ in range(nhead)]
            )

        # input LayerNorm
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)

        # FFN LayerNorm
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(d_model, d_model),
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        batch, N, T, D = x.shape
        x = x.view(-1, T, D)  # [batch * N, T, D]
        x = self.norm1(x)
        q = self.qtrans(x)
        k = self.ktrans(x)
        v = self.vtrans(x)

        dim = self.d_model // self.nhead
        att_output = []
        for i in range(self.nhead):
            # The last head takes every remaining channel.
            stop = None if i == self.nhead - 1 else (i + 1) * dim
            qh = q[:, :, i * dim:stop]
            kh = k[:, :, i * dim:stop]
            vh = v[:, :, i * dim:stop]
            # [batch * N, T, dim] x [batch * N, dim, T] --> [batch * N, T, T]
            atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1)
            if self.attn_dropout:
                atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh)
            # [batch * N, T, T] x [batch * N, T, dim] --> [batch * N, T, dim]
            att_output.append(torch.matmul(atten_ave_matrixh, vh))
        att_output = torch.cat(att_output, dim=-1)  # [batch * N, T, D]

        # Residual on the normalised input, then FFN with its own residual.
        xt = x + att_output
        xt = self.norm2(xt)
        att_output = xt + self.ffn(xt)

        return att_output.view(batch, N, T, D)


class TemporalAttention(nn.Module):
    """Collapse the time axis with attention queried by the last time step.

    Input ``[batch, N, T, D]`` -> output ``[batch, N, D]``. Scores are the dot
    products between the linearly transformed sequence and its transformed last
    step; the softmax weights are applied to the *untransformed* input.
    """

    def __init__(self, d_model):
        super().__init__()
        self.trans = nn.Linear(d_model, d_model, bias=False)

    def forward(self, z):
        batch, N, T, D = z.shape
        series = z.view(-1, T, D)                       # [batch * N, T, D]
        projected = self.trans(series)                  # [batch * N, T, D]
        last_step = projected[:, -1, :].unsqueeze(-1)   # [batch * N, D, 1]
        scores = torch.matmul(projected, last_step).squeeze(-1)  # [batch * N, T]
        weights = torch.softmax(scores, dim=1).unsqueeze(1)      # [batch * N, 1, T]
        pooled = torch.matmul(weights, series).squeeze(1)        # [batch * N, D]
        return pooled.view(batch, N, D)

class Gate(nn.Module):
    """Feature gate: linear map, temperature softmax, scaled by the output width.

    Each output row sums to ``d_output``. ``beta`` is the softmax temperature
    (stored as ``self.t``).
    """

    def __init__(self, d_input, d_output, beta=1.0):
        super().__init__()
        self.trans = nn.Linear(d_input, d_output)
        self.d_output = d_output
        self.t = beta

    def forward(self, gate_input):
        logits = self.trans(gate_input)
        shares = torch.softmax(logits / self.t, dim=-1)
        return self.d_output * shares

class MASTER(nn.Module):
    """Stack of feature projection, temporal and cross-sectional attention, and a scalar head.

    Pipeline (``self.layers``): Linear(d_feat -> d_model) -> PositionalEncoding
    -> TAttention -> SAttention -> TemporalAttention -> Linear(d_model -> 1).

    Args:
        d_feat: input feature width.
        d_model: hidden width.
        t_nhead / s_nhead: heads of the temporal / cross-sectional attention.
        T_dropout_rate / S_dropout_rate: dropout of the respective attention blocks.
        gate_input_start_index, gate_input_end_index, beta: accepted but
            currently unused (the feature gate is disabled).
    """

    def __init__(self, d_feat=76, d_model=256, t_nhead=2, s_nhead=2, T_dropout_rate=0.5, S_dropout_rate=0.5,
                 gate_input_start_index=None, gate_input_end_index=None, beta=None):
        super().__init__()
        stages = [
            # feature layer
            nn.Linear(d_feat, d_model),
            PositionalEncoding(d_model),
            # intra-stock aggregation
            TAttention(d_model=d_model, nhead=t_nhead, dropout=T_dropout_rate),
            # inter-stock aggregation
            SAttention(d_model=d_model, nhead=s_nhead, dropout=S_dropout_rate),
            TemporalAttention(d_model=d_model),
            # decoder
            nn.Linear(d_model, 1),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the scalar prediction per cross-section element, clamped to [-5, 5]."""
        raw_scores = self.layers(x).squeeze(-1)
        return torch.clamp(raw_scores, -5.0, 5.0)

lgbm

In [15]:
import lightgbm as lgb

# Load a LightGBM booster from its saved model file in the attached Kaggle dataset.
lgb_model =lgb.Booster(model_file="/kaggle/input/aft-js-dataset/lgb_1_14.lgb")



In [16]:
# Column names to standardise: the pickled column list minus its first 6 and last 9 entries.
# NOTE(review): presumably the dropped entries are id/weight and responder columns — confirm against data_columns.pkl.
all_feature = pd.read_pickle("/kaggle/input/aft-js-dataset/data_columns.pkl")[6:-9]
# Precomputed statistics; they are subtracted from / divided into data[all_feature]
# in standardize_polars, so they must broadcast against that array.
means, stds = pd.read_pickle("/kaggle/input/aft-js-dataset/mean_std.pkl")
def standardize_polars(data: pl.DataFrame, means=means, stds=stds):
    """Standardise the ``all_feature`` columns of ``data``.

    Each listed column becomes ``(value - means) / stds`` clipped to
    [-5, 5]; all other columns are returned unchanged.

    :param data: frame containing every column named in ``all_feature``.
    :param means: values subtracted from the features (defaults to the
        module-level ``means`` captured when the function is defined).
    :param stds: values the centred features are divided by (defaults to the
        module-level ``stds`` captured when the function is defined).
    :return: a frame in which the feature columns are replaced by their
        standardised, clipped values.
    """
    raw_values = data[all_feature].to_numpy()
    z_scores = (raw_values - means) / stds
    clipped = np.clip(z_scores, -5, 5)
    replacement_columns = [
        pl.Series(clipped[:, col_idx]).alias(col_name)
        for col_idx, col_name in enumerate(all_feature)
    ]
    return data.with_columns(replacement_columns)

**初始模型和dp都定义好后，开始online training**

In [17]:
# def get_one_batch(data, sequence_len, feature_dim):
#     symbols = data.filter(pl.col("time_id") == data['time_id'].max())['symbol_id'].to_list()
#     output_array = np.zeros((1, len(symbols), sequence_len, feature_dim))
#     for i, sym in enumerate(symbols):
#         symbol_data = data.filter(pl.col("symbol_id") == sym).tail(sequence_len).drop(SpecialCols.id_cols+['weight'])
#         output_array[0, i, -len(symbol_data):, :] = symbol_data.to_numpy()[:, :]
#     return output_array

In [18]:
class OnlineTraining:
    """Serve ensemble predictions through the inference server while
    periodically fine-tuning the incremental models on recently seen days.

    Every scored batch is pushed into the data processor, cached per day and,
    when a new ``date_id`` arrives together with ``lags``, the finished day is
    joined with its responders and written to ``work_dir``. After
    ``retrain_days`` stored days every ``IncrementalTrainer`` in ``models`` is
    updated on those days (the most recent one is used for validation).
    """

    # Directory in which the per-day parquet files are stored.
    work_dir = "/kaggle/working"

    def __init__(self, data_processor: HFDataProcessor, models: List):
        """
        :param data_processor: object providing ``update(...)`` and ``get()``
            for the rolling feature history.
        :param models: list of models to ensemble; entries may be
            ``IncrementalTrainer`` instances or tree models exposing
            ``predict``.
        """
        self.dp = data_processor
        self.lags_ = None  # lagged responders (not referenced elsewhere in this class)

        self.models = models
        self.retrain_days = 7  # number of stored days between updates; must be at least 2
        self.days_passed = 0
        self.last_date = None
        self.date_cache = []
        self.saved_dates = []  # date_ids whose per-day parquet file is currently on disk

        self.step_len = 16  # fixed sequence length passed to get_dataloader
        self.input_dim = 237  # feature dimension fed to the models
        self.responder6_pos = -3  # position of responder_6 (not referenced elsewhere in this class)

        self.pbar = tqdm(desc="test prediction")

        self.test_parquet = '/kaggle/input/aft-js-dataset/synthetic_test.parquet_short'
        self.lag_parquet = '/kaggle/input/aft-js-dataset/synthetic_lag.parquet_short'

    def run_inference_server(self):
        """Start the inference server: ``serve()`` during a competition rerun,
        otherwise the local gateway on the synthetic test/lag files."""
        inference_server = js_server.JSInferenceServer(self.predict)
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
            inference_server.serve()
        else:
            inference_server.run_local_gateway((self.test_parquet, self.lag_parquet))

    def _day_file(self, date_id):
        """Return the path of the parquet file that stores ``date_id``."""
        return f"{self.work_dir}/data_of_{date_id}.parquet"

    @staticmethod
    def _zero_predictions(predictions):
        """Return ``predictions`` with a constant-0 ``responder_6`` column."""
        return predictions.with_columns(pl.lit(0).alias('responder_6'))

    def _flush_previous_day(self, lags):
        """Join the cached rows of ``self.last_date`` with their responders
        (taken from ``lags``) and write them to that day's parquet file."""
        # The lag frame carries the new date_id; shift it back by one so it
        # lines up with the cached day, and strip the '_lag_1' suffix.
        _lags = lags.with_columns(pl.col('date_id') - 1)
        _lags = _lags.rename({col: col.replace('_lag_1', '') for col in _lags.columns if '_lag_1' in col})
        last_date_all_data = pl.concat(self.date_cache)
        last_date_all_data = last_date_all_data.join(_lags, on=SpecialCols.id_cols, how='left')
        last_date_all_data.write_parquet(self._day_file(self.last_date))
        if self.last_date not in self.saved_dates:
            self.saved_dates.append(self.last_date)
        self.date_cache = []

    def _retrain(self):
        """Update every ``IncrementalTrainer`` on the most recent
        ``retrain_days`` stored days and delete all older day files."""
        window = self.saved_dates[-self.retrain_days:]
        # Remove every file outside the current window. The previous version
        # removed a single file per cycle, so old days piled up on disk.
        for stale_date in self.saved_dates[:-self.retrain_days]:
            stale_file = self._day_file(stale_date)
            if os.path.exists(stale_file):
                os.remove(stale_file)
        self.saved_dates = window

        if len(window) < 2:  # need at least one training day and one validation day
            print("not enough stored days for retraining, skipped.")
            return

        # Most recent stored day validates; the others train. Training files
        # are concatenated newest-first, the same order as the original
        # date_id-2, date_id-3, ... enumeration.
        valid_date = window[-1]
        train_dates = window[-2::-1]
        retrain_data_train = pl.concat(
            [pl.read_parquet(self._day_file(d)) for d in train_dates], how='vertical'
        ).drop(['row_id', 'is_scored'])
        retrain_data_valid = pl.read_parquet(self._day_file(valid_date)).drop(['row_id', 'is_scored'])

        retrain_data_train = standardize_polars(retrain_data_train)
        retrain_data_valid = standardize_polars(retrain_data_valid)

        train_loader = get_dataloader(retrain_data_train, self.step_len)
        test_loader = get_dataloader(retrain_data_valid, self.step_len)
        for trainer in self.models:
            if isinstance(trainer, IncrementalTrainer):
                trainer.incremental_update(train_loader, test_loader)
        gc.collect()

    def _ensemble_predict(self, X):
        """Return the plain average of all models' flattened predictions."""
        cnt, res = 0, 0
        for trainer in self.models:
            if isinstance(trainer, IncrementalTrainer):
                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    preds = trainer.model(X.to(trainer.device))
            else:
                # Tree model: fed the last step of every sequence in the batch.
                preds = trainer.predict(X[0, :, -1, :self.input_dim].detach().numpy())

            if isinstance(preds, torch.Tensor):
                preds = preds.detach().cpu().numpy().reshape(-1)
            elif isinstance(preds, (pd.DataFrame, pd.Series)):
                preds = preds.values.reshape(-1)
            else:
                preds = preds.reshape(-1)

            res += preds
            cnt += 1
        return res / cnt

    def predict(self, test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
        """Inference-server callback.

        :param test: rows of one (date_id, time_id) step; must contain
            ``row_id``, ``is_scored``, ``date_id`` and ``time_id``.
        :param lags: lagged responders, or ``None`` when none are delivered
            with this step.
        :return: frame with ``row_id`` and ``responder_6``. The column is all
            zeros when no row is scored, when no forecast batch could be
            built, or when the predictions cannot be attached to ``test``.
        """
        predictions = test.select(['row_id'])
        if not test['is_scored'].any():  # nothing to predict in this batch
            return self._zero_predictions(predictions)

        date_id, time_id = test['date_id'][0], test['time_id'][0]

        if self.last_date is None:
            self.last_date = date_id

        # A new day started: persist the finished day and retrain if due.
        if self.last_date != date_id and lags is not None:
            start = time.time()
            self._flush_previous_day(lags)
            self.days_passed += 1
            if self.days_passed >= self.retrain_days:
                print("start retraining")
                self.days_passed = 0
                self._retrain()
            end = time.time()
            print(end-start)

        # Push the new rows (targets set to 0) into the data processor.
        test = test.with_columns([pl.lit(0).alias(f) for f in SpecialCols.target_cols])
        self.dp.update(date_id, time_id, test, lags)
        one_time_data = self.dp.get()
        newly_updated = one_time_data.filter(pl.col("time_id") == time_id).filter(pl.col("date_id") == date_id)
        self.date_cache.append(newly_updated)
        self.last_date = date_id

        # Build the model input from the first batch the loader yields
        # (presumably the only one, given get_last_batch=True -- TODO confirm).
        one_time_data = one_time_data.with_columns([pl.lit(0).alias(t) for t in SpecialCols.target_cols]).drop(['is_scored', 'row_id'])
        one_time_data = standardize_polars(one_time_data)
        forecast_loader = get_dataloader(one_time_data, step_len=self.step_len, get_last_batch=True)
        X = None
        for combined_batch in forecast_loader:
            X = combined_batch.float()[:, :, :, :self.input_dim]
            break
        if X is None:
            # The loader produced nothing; answer with zeros instead of crashing.
            print("no forecast batch available.")
            self.pbar.update(1)
            return self._zero_predictions(predictions)

        preds = self._ensemble_predict(X)

        # Attach the clipped predictions; best-effort fallback to zeros when
        # they cannot be attached (e.g. their length differs from the batch).
        try:
            predictions = predictions.with_columns(pl.Series(np.clip(preds, a_min = -5, a_max = 5)).alias('responder_6'))
        except Exception as exc:
            print(f"shape cannot match. ({exc!r})")
            predictions = self._zero_predictions(predictions)
        self.pbar.update(1)
        return predictions

In [19]:
# Optional warm start of the data processor from a training partition
# (currently disabled):
# init_data = pl.read_parquet(os.path.join(CONFIG.path, "train.parquet", "partition_id=9", "part-0.parquet"))
dp = HFDataProcessor()
# dp.init_data(init_data)
# del init_data
# gc.collect()
# Ensemble of the two incremental trainers and the LightGBM booster, then
# start the inference server with OnlineTraining.predict as its callback.
js_predictor = OnlineTraining(dp, [gru_trainer, transformer_trainer, lgb_model])
js_predictor.run_inference_server()

test prediction: 969it [02:38,  1.97it/s]

1.0060441493988037


test prediction: 1937it [05:18,  2.13it/s]

0.9282479286193848


test prediction: 2904it [08:00,  6.05it/s]

0.9317927360534668


test prediction: 3872it [10:48,  5.74it/s]