In [14]:


from src.utils.path import PathManager
from src.config.app_config import AppConfig
from src.observability.instrumentation import Instrumentation
from src.pipeline.step import BasePipelineStep
from src.pipeline.context import PipelineContext
from src.utils.filesystem import FileSystem
from src.utils.logger import logs

In [2]:
cfg = AppConfig.load()
pm = PathManager()
inst = Instrumentation()
d = '2015-01-01'

raw_dir = pm.raw_dir(d)
parquet_dir = pm.parquet_dir(d)
symbol_dir = pm.symbol_dir(d)
normalize_dir = pm.canonical_dir(d)
meta_dir = pm.meta_dir(d)

FileSystem.ensure_dir(raw_dir)
FileSystem.ensure_dir(parquet_dir)
FileSystem.ensure_dir(symbol_dir)
FileSystem.ensure_dir(normalize_dir)
FileSystem.ensure_dir(meta_dir)

ctx = PipelineContext(
    date=d,
    raw_dir=raw_dir,
    parquet_dir=parquet_dir,
    symbol_dir=symbol_dir,
    canonical_dir=normalize_dir,
    meta_dir=meta_dir
)


In [3]:
class CsvConvertStep(BasePipelineStep):
    def __init__(self, engine, inst=None):
        super().__init__(inst)
        self.engine = engine

    def run(self, ctx: PipelineContext):
        input_dir = ctx.raw_dir
        out_dir = ctx.parquet_dir

        for zfile in input_dir.glob("*.7z"):
            out_files = self._build_out_files(zfile, out_dir)
            if self._all_exist(out_files):
                print(f"[CsvConvertStep] skip {zfile.name}")
                continue
            print(f"[CsvConvertStep]  {zfile.name} {out_files}")
            self.engine.convert(zfile, out_files)

    def _detect_type(self, filename):
        """
        根据文件名约定识别 file_type：
            SH_Stock_OrderTrade.csv.7z → SH_MIXED
            SH_Order.csv.7z           → SH_ORDER
            SH_Trade.csv.7z           → SH_TRADE
            SZ_Order.csv.7z           → SZ_ORDER
            SZ_Trade.csv.7z           → SZ_TRADE
        """
        lower = filename.lower()

        if lower.startswith("sh_stock_ordertrade"):
            return "SH_MIXED"

        if lower.startswith("sh_order"):
            return "SH_ORDER"
        if lower.startswith("sh_trade"):
            return "SH_TRADE"

        if lower.startswith("sz_order"):
            return "SZ_ORDER"
        if lower.startswith("sz_trade"):
            return "SZ_TRADE"

        raise RuntimeError(f"无法识别文件类型: {filename}")

    def _build_out_files(self, zfile: Path, parquet_dir: Path) -> dict[str, Path]:
        file_type = self._detect_type(zfile.stem)
        if file_type == "SH_MIXED":
            return {
                "sh_order": parquet_dir / "sh_order.parquet",
                "sh_trade": parquet_dir / "sh_trade.parquet",
            }

        stem = zfile.stem.replace(".csv", "")
        return {
            stem.lower(): parquet_dir / f"{stem.lower()}.parquet"
        }

    @staticmethod
    def _all_exist(out_files: dict[str, Path]) -> bool:
        return all(p.exists() for p in out_files.values())


from pathlib import Path
from src.engines.extractor_engine import ExtractorEngine


class ParquetAppendWriter:
    def __init__(self):
        self._schemas: dict[Path, pa.Schema] = {}
        self._writers: dict[Path, pq.ParquetWriter] = {}

    def write_batches(self, path: Path, batches: list[pa.RecordBatch]) -> None:
        if not batches:
            return

        writer = self._writers.get(path)
        if writer is None:
            # schema 来自第一个 batch
            schema = batches[0].schema
            path.parent.mkdir(parents=True, exist_ok=True)
            writer = pq.ParquetWriter(
                path,
                schema,
                compression="zstd",
            )
            self._writers[path] = writer
            self._schemas[path] = schema

        table = pa.Table.from_batches(batches, schema=self._schemas[path])
        writer.write_table(table)

    def close(self) -> None:
        for writer in self._writers.values():
            writer.close()
        self._writers.clear()


import pyarrow.compute as pc


class ConvertEngine:
    ORDER_TYPES = ["A", "D", "M"]
    TRADE_TYPE = "T"
    TICK_COL = "TickType"

    def __init__(self):
        self.extractor = ExtractorEngine
        self.order_set = pa.array(self.ORDER_TYPES)
        self.trade_value = pa.scalar(self.TRADE_TYPE)

    def convert(self, zfile: Path, out_files: dict[str, Path]) -> None:
        reader = self.extractor.open_reader(zfile)
        writer = ParquetAppendWriter()
        try:
            for batch in reader:
                batch = self.extractor.cast_strings(batch)

                if len(out_files) == 1:
                    # 非拆分
                    key = next(iter(out_files))
                    writer.write_batches(out_files[key], [batch])
                else:
                    # 拆分
                    for key, sub_batch in self._split(batch, out_files).items():
                        if sub_batch.num_rows:
                            writer.write_batches(out_files[key], [sub_batch])
        finally:
            writer.close()

    def _split(self, batch: pa.RecordBatch, out_files: dict[str, Path]) -> dict[str, pa.RecordBatch]:
        """返回 (order_batch, trade_batch)"""
        if self.TICK_COL not in batch.schema.names:
            raise ValueError(f"missing column: {self.TICK_COL}")

        idx = batch.schema.get_field_index(self.TICK_COL)
        tick_arr = batch.column(idx)

        order_mask = pc.is_in(tick_arr, self.order_set)
        trade_mask = pc.equal(tick_arr, self.trade_value)

        result = {}
        for key in out_files:
            if "order" in key:
                result[key] = batch.filter(order_mask)
            elif "trade" in key:
                result[key] = batch.filter(trade_mask)

        return result


In [4]:


from src import DateTimeUtils
from functools import reduce


class NormalizeEngine:
    """
    NormalizeEngine（冻结契约版）

    - 输入：交易所级 parquet
    - 输出：canonical order / trade parquet
    - symbol 只是字段，不做拆分
    """

    VALID_EVENTS = {"ADD", "CANCEL", "TRADE"}
    batch_size = 1_000_0000

    def execute(self, input_file: Path, output_dir: Path) -> None:
        exchange, kind = input_file.stem.split("_", 1)
        out_path = output_dir / input_file.name

        pf = pq.ParquetFile(input_file)
        writer = None

        for batch in pf.iter_batches(self.batch_size):
            table = pa.Table.from_batches([batch])
            table = self.filter_a_share_arrow(table)
            if table.num_rows == 0:
                continue

            table = parse_events_arrow(
                table,
                exchange=exchange,
                kind=kind,
            )

            if table.num_rows == 0:
                continue
            if writer is None:
                writer = pq.ParquetWriter(out_path, table.schema)

            writer.write_table(table)

        if writer:
            writer.close()

    def filter_a_share_arrow(self, table: pa.Table) -> pa.Table:
        symbol = pc.cast(table["SecurityID"], pa.string())

        # prefixes = [
        #     "600", "601", "603", "605", "688",
        #     "000", "001", "002", "003", "300",
        # ]
        prefixes = [
            "60", "688",
            "00", "300",
        ]

        masks = [pc.starts_with(symbol, p) for p in prefixes]

        mask = reduce(pc.or_, masks)

        return table.filter(mask)


class NormalizeStep(BasePipelineStep):
    def __init__(self, engine: NormalizeEngine, inst=None):
        super().__init__(inst)
        self.engine = engine

    def run(self, ctx: PipelineContext) -> PipelineContext:
        input_dir: Path = ctx.parquet_dir
        output_dir: Path = ctx.canonical_dir

        for file in list(input_dir.glob("*.parquet")):
            filename = file.stem
            output_file = output_dir / filename

            if output_file.exists():
                logs.info(f'')
                continue
            self.engine.execute(
                input_file=input_dir / file,
                output_dir=output_dir,
            )


from datetime import datetime

# =============================================================================
# Internal Event Schema（唯一真相）
# =============================================================================
# from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Literal

import pyarrow as pa
import pyarrow.compute as pc

EventKind = Literal["order", "trade"]

INTERNAL_SCHEMA = pa.schema(
    [('symbol', pa.string()),
     ("ts", pa.int64()),
     ("event", pa.string()),
     ("order_id", pa.int64()),
     ("side", pa.string()),
     ("price", pa.float64()),
     ("volume", pa.int64()),
     ("buy_no", pa.int64()),
     ("sell_no", pa.int64()),
     ]
)


@dataclass(frozen=True)
class ExchangeDefinition:
    symbol_field: str
    time_field: str
    event_field: str
    event_mapping: Dict
    price_field: str
    volume_field: str
    side_field: Optional[str]
    side_mapping: Optional[Dict]
    id_field: str
    buy_no_field: Optional[str]
    sell_no_field: Optional[str]


EXCHANGE_REGISTRY = {
    # 上海
    'sh': {
        "order": ExchangeDefinition(
            symbol_field='SecurityID',
            time_field="TickTime",
            event_field="TickType",
            event_mapping={"A": "ADD", "D": "CANCEL"},
            price_field="Price",
            volume_field="Volume",
            side_field="Side",
            side_mapping={'1': "B", '2': "S"},
            id_field="SubSeq",
            buy_no_field=None,
            sell_no_field=None,
        ),
        "trade": ExchangeDefinition(
            symbol_field='SecurityID',
            time_field="TickTime",
            event_field="TickType",
            event_mapping={"T": "TRADE"},
            price_field="Price",
            volume_field="Volume",
            side_field="Side",
            side_mapping={'1': "B", '2': "S"},
            id_field="SubSeq",
            buy_no_field="BuyNo",
            sell_no_field="SellNo",
        ),
    },

    # 深圳
    'sz': {
        "order": ExchangeDefinition(
            symbol_field='SecurityID',
            time_field="OrderTime",
            event_field="OrderType",
            event_mapping={'0': "CANCEL", '1': "ADD", '2': "ADD", '3': "ADD"},
            price_field="Price",
            volume_field="Volume",
            side_field="Side",
            side_mapping={'1': "B", '2': "S"},
            id_field="SubSeq",
            buy_no_field=None,
            sell_no_field=None,
        ),

        "trade": ExchangeDefinition(
            symbol_field='SecurityID',
            time_field="TickTime",
            event_field="ExecType",
            event_mapping={'1': "TRADE", '2': "CANCEL"},
            price_field="TradePrice",
            volume_field="TradeVolume",
            side_field=None,
            side_mapping=None,
            id_field="SubSeq",
            buy_no_field="BuyNo",
            sell_no_field="SellNo",
        ),
    },
}


# MAPPING_kind = {
#     '1':'order',
#     '2':'trade',
# }

# =============================================================================
# 2. TickTime -> offset_us （执行层：Arrow vectorized）
# =============================================================================


def trade_time_to_base_us(trade_time) -> int:
    """
    使用 DateTimeUtils 作为唯一语义来源
    """
    d = DateTimeUtils.extract_date(trade_time)

    base_dt = datetime(
        d.year,
        d.month,
        d.day,
        tzinfo=DateTimeUtils.SH_TZ,
    )
    return int(base_dt.timestamp() * 1_000_000)


def _mod(a: pa.Array, b: int) -> pa.Array:
    """
    Arrow-safe modulo, version independent:
        a % b == a - floor(a / b) * b
    """
    return pc.subtract(
        a,
        pc.multiply(
            pc.cast(pc.floor(pc.divide(a, b)), pa.int64()),
            pa.scalar(b, pa.int64()),
        ),
    )


def tick_to_offset_us(col: pa.Array) -> pa.Array:
    t = pc.cast(col, pa.int64())

    # HH
    hh = pc.cast(pc.floor(pc.divide(t, 1_000_000)), pa.int64())

    # MM
    mm_all = pc.cast(pc.floor(pc.divide(t, 10_000)), pa.int64())
    mm = _mod(mm_all, 100)

    # SS
    ss_all = pc.cast(pc.floor(pc.divide(t, 100)), pa.int64())
    ss = _mod(ss_all, 100)

    # mmm (milliseconds)
    ms = _mod(t, 1_000)

    return pc.add(
        pc.add(
            pc.add(
                pc.multiply(hh, pa.scalar(3_600_000_000, pa.int64())),
                pc.multiply(mm, pa.scalar(60_000_000, pa.int64())),
            ),
            pc.multiply(ss, pa.scalar(1_000_000, pa.int64())),
        ),
        pc.multiply(ms, pa.scalar(1_000, pa.int64())),
    )


def map_dict(col: pa.Array, mapping: dict) -> pa.Array:
    keys = pa.array(list(mapping.keys()))
    vals = pa.array(list(mapping.values()))
    idx = pc.index_in(col, keys)
    return pc.take(vals, idx)


def zeros(n: int) -> pa.Array:
    return pa.array([0] * n, type=pa.int64())


def parse_events_arrow(
        table: pa.Table,
        kind: Literal["order", "trade"] = '',
        exchange: str = ''
) -> pa.Table:
    """
    输入：
        Arrow Table（单 symbol / 单 kind / 单 exchange）
    输出：
        Arrow Table（InternalEvent schema）
    """
    if table.num_rows == 0:
        return pa.Table.from_arrays([])

    try:
        definition = EXCHANGE_REGISTRY[exchange][kind]
    except KeyError:
        raise KeyError(f"No registry for exchange={exchange}, kind={kind}")
    # # ---------------------------------------------------------------------
    #     # ts
    #     # ---------------------------------------------------------------------
    # # print(table["TradeTime"][0])
    base_us = trade_time_to_base_us(table["TradeTime"][0].as_py())
    offset_us = tick_to_offset_us(table[definition.time_field])  # Array
    ts = pc.add(offset_us, pa.scalar(base_us, pa.int64()))  # Array

    # ---------------------------------------------------------------------
    # event
    # ---------------------------------------------------------------------
    event = map_dict(table[definition.event_field], definition.event_mapping)

    # ---------------------------------------------------------------------
    # side
    # ---------------------------------------------------------------------
    if definition.side_field and definition.side_mapping:
        side = map_dict(table[definition.side_field], definition.side_mapping)
    else:
        side = pa.nulls(table.num_rows)
    #
    # ---------------------------------------------------------------------
    # buy / sell no
    # ---------------------------------------------------------------------
    buy_no = (
        table[definition.buy_no_field]
        if definition.buy_no_field
        else zeros(table.num_rows)
    )
    sell_no = (
        table[definition.sell_no_field]
        if definition.sell_no_field
        else zeros(table.num_rows)
    )
    out = pa.table(
        {"symbol": pc.cast(table[definition.symbol_field], pa.string()),
         "ts": ts,
         "event": event,
         "order_id": pc.cast(table[definition.id_field], pa.int64()),
         "side": side,
         "price": pc.cast(table[definition.price_field], pa.float64()),
         "volume": pc.cast(table[definition.volume_field], pa.int64()),
         "buy_no": pc.cast(buy_no, pa.int64()),
         "sell_no": pc.cast(sell_no, pa.int64()),
         }
    )

    return out.cast(INTERNAL_SCHEMA)




In [5]:
#!filepath: src/engines/symbol_split_engine.py
from __future__ import annotations

import pyarrow as pa
from typing import Iterable


class SymbolSplitEngine:
    """
    SymbolSplitEngine（纯逻辑）：

    Input:
        - canonical Events.parquet（Arrow Table / Reader）
        - symbol: str

    Output:
        - bytes（该 symbol 的 parquet 内容）

    约束：
        - 不做 IO
        - 不依赖 Path
        - 不接触 Meta
    """

    def __init__(self, symbol_field: str = "symbol"):
        self.symbol_field = symbol_field

    # --------------------------------------------------
    def split_one(
            self,
            table: pa.Table,
            symbol: str,
    ) -> bytes:
        """
        从 canonical table 中切出某一个 symbol
        """
        mask = pa.compute.equal(table[self.symbol_field], symbol)
        sub = table.filter(mask)

        sink = pa.BufferOutputStream()
        pq.write_table(sub, sink)

        return sink.getvalue().to_pybytes()

    # --------------------------------------------------
    def split_many(
            self,
            table: pa.Table,
            symbols: Iterable[str],
    ) -> dict[str, bytes]:
        """
        一次切多个 symbol（可选优化）
        """
        result: dict[str, bytes] = {}

        for sym in symbols:
            result[sym] = self.split_one(table, sym)

        return result


In [12]:
#!filepath: src/steps/symbol_split_step.py
from __future__ import annotations

import pyarrow.parquet as pq

from src.pipeline.step import PipelineStep
from src.pipeline.meta import MetaRegistry


class SymbolSplitStep(PipelineStep):
    """
    SymbolSplitStep（Meta-aware，冻结版）

    Semantic:
        canonical Events.parquet
            → symbol/{symbol}/{date}/Trade.parquet
    SymbolSplitStep — DAILY-CLOSED (data-driven) FINAL VERSION

    Semantics (FROZEN):

    - Meta is DATE-scoped.
    - Daily universe is defined ONLY by that day's meta.outputs.
    - First run (no meta):
        * Read canonical once
        * Discover symbols appearing on THIS date
        * Full split
        * Write meta (universe = discovered symbols)
    - Subsequent runs:
        * Universe = meta.outputs.keys()
        * If all outputs valid -> SKIP (NO canonical IO)
        * If some outputs invalid/missing -> read canonical and repair ONLY those symbols
    - Does NOT detect symbols missing due to upstream canonical issues.
    """

    def __init__(
            self,
            engine: SymbolSplitEngine,
            inst=None,
    ):
        self.engine = engine
        self.inst = inst

    # --------------------------------------------------
    def run(self, ctx):
        input_dir: Path = ctx.canonical_dir
        output_dir: Path = ctx.symbol_dir

        meta_dir: Path = ctx.meta_dir

        outputs = {}
        for file in list(input_dir.glob("*.parquet")):
            # ① 修正 step 语义：pipeline step + file
            step_key = f"{self.__class__.__name__}:{file.stem}"

            meta = MetaRegistry(
                meta_file=meta_dir / step_key,
                step=file.stem,
                date=ctx.date,
                engine_version="v1",
                input_file=input_dir,
            )
            manifest = meta.load()

            # ---------------------------------------------
            # ① 决定需要 split 的 symbol
            # ---------------------------------------------
            if manifest is None or meta.is_input_changed():
                table = pq.read_table(file, columns=["symbol"])
                symbols = table["symbol"].unique().to_pylist()
            else:
                status = meta.validate_outputs()
                symbols = [k for k, ok in status.items() if not ok]

            if not symbols:
                continue
            # ② 读取 canonical table（一次）
            table = pq.read_table(file)
            # ③ 执行 split（纯逻辑）
            payloads = self.engine.split_many(table, symbols)

            # ④ 写文件 + 记录 meta
            meta.begin_new()

            for sym, data in payloads.items():
                out_file = output_dir / sym / file.name.split('_')[1]
                FileSystem.safe_write(out_file, data)
                meta.record_output(sym, out_file)

            meta.commit()











In [13]:

cs = CsvConvertStep(engine=ConvertEngine(), inst=inst)
cs.run(ctx)

ns = NormalizeStep(engine=NormalizeEngine(), inst=inst)
ns.run(ctx)

sp = SymbolSplitStep(engine=SymbolSplitEngine(), inst=inst)
sp.run(ctx)



[CsvConvertStep]  SZ_Order.csv.7z {'sz_order': PosixPath('/home/wsw/data/parquet/2015-01-01/sz_order.parquet')}
[CsvConvertStep]  SH_Stock_OrderTrade.csv.7z {'sh_order': PosixPath('/home/wsw/data/parquet/2015-01-01/sh_order.parquet'), 'sh_trade': PosixPath('/home/wsw/data/parquet/2015-01-01/sh_trade.parquet')}
[CsvConvertStep]  SZ_Trade.csv.7z {'sz_trade': PosixPath('/home/wsw/data/parquet/2015-01-01/sz_trade.parquet')}


FileNotFoundError: [Errno 2] Failed to open local file '/home/wsw/data/canonical/2015-01-01/sh_order.parquet'. Detail: [errno 2] No such file or directory

In [14]:
t = pq.read_table('/home/wsw/data/canonical/2015-01-01/sh_order.parquet')

In [None]:
t.shape

In [None]:
t.slice(0, 5)

In [None]:
t.take([0, 1, 2])

In [None]:
import json

In [None]:
with open('/home/wsw/data/meta/2015-01-01/SymbolSplitStep:sh_order.parquet.json') as f:
    details = json.load(f)

In [None]:
details

In [None]:
with open('/home/wsw/data/meta/2015-01-01/SymbolSplitStep:sh_trade.parquet.json') as f:
    details2 = json.load(f)
details2