In [2]:
from datetime import datetime, timedelta

import torch
import torch.nn.functional as F
import baostock as bs
import akshare as ak
import pandas as pd

- 手续费：1.3e-4
- 印花税：0
- 滑点：0.01

- 实际花费：(开盘价+滑点) * (1+手续费)
- 实际收入：(收盘价-滑点)*(1-手续费-印花税)

In [3]:
pe_df = pd.read_csv("../data/000300_pe.csv", encoding="gbk")
index_df = pd.read_csv("../data/000300_price.csv", encoding="gbk")
bond_df = pd.read_csv("../data/10yr_bond_yield.csv", encoding="gbk")
bond_df.head()

Unnamed: 0,tradeDate,curveName,maturity,curveType,yield
0,2025/4/3,??????§?????,10,????,1.785
1,2025/4/2,??????§?????,10,????,1.86
2,2025/4/1,??????§?????,10,????,1.874
3,2025/3/31,??????§?????,10,????,1.876
4,2025/3/28,??????§?????,10,????,1.884


In [4]:
pe_df["tradeDate"] = pd.to_datetime(pe_df["tradeDate"])
index_df["tradeDate"] = pd.to_datetime(index_df["tradeDate"])
bond_df["tradeDate"] = pd.to_datetime(bond_df["tradeDate"])

In [5]:
index_df = index_df.merge(pe_df, how="left", on="tradeDate")

In [6]:
index_df = index_df[index_df["tradeDate"] >= "2013-01-01"]
len(index_df)

2974

In [7]:
bond_df = bond_df[bond_df["tradeDate"] >= "2013-01-01"]
len(bond_df)

3031

In [8]:
overall_df = index_df.merge(bond_df, how="left", on="tradeDate")

- 总长度2048
- prompt长度5

In [15]:
from typing import List

import torch
import torch.nn.functional as F
import pandas as pd
from datasets import Dataset


class PriceProcessor:

    def __init__(
        self,
        df: pd.DataFrame,
        input_cols: List[str],
        reward_cols: List[str],
        date_col: str = "tradeDate",
        prompt_len: int = 5,
        max_len: int = 2048,
    ):
        cols = list(set(input_cols + reward_cols))
        self.df = df[[date_col] + cols].sort_values(
            by=date_col, ignore_index=True).bfill().ffill()
        self.input_cols = input_cols
        self.reward_cols = reward_cols
        self.date_col = date_col
        self.prompt_len = prompt_len
        self.max_len = max_len

    def __call__(self, start_time: str):

        if isinstance(start_time, str):
            start_time = pd.to_datetime(start_time)
        window_df = self.df.loc[self.df[self.date_col] >=
                                start_time].iloc[:self.max_len]
        # window_df[self.date_col] = window_df[self.date_col].apply(
        #     lambda x: x.timestamp())

        input_ids = torch.tensor(window_df[self.input_cols].values,
                                 dtype=torch.float32)

        # pad
        prefix_fill = max(0, self.prompt_len - input_ids.size(0))
        suffix_fill = max(0, self.max_len - input_ids.size(0) - prefix_fill)
        pad_dim = (0, 0, prefix_fill, suffix_fill)
        input_ids = F.pad(input_ids, pad_dim, value=0)

        # mask
        attention_mask = torch.ones(input_ids.shape[:-1], dtype=torch.float32)
        if prefix_fill:
            attention_mask[:prefix_fill] = 0
        if suffix_fill:
            attention_mask[-suffix_fill:] = 0

        # split
        prompt_ids, prompt_mask = input_ids[:self.
                                            prompt_len], attention_mask[:self.
                                                                        prompt_len]
        completion_ids, completion_mask = input_ids[
            self.prompt_len:], attention_mask[self.prompt_len:]

        result = {
            # "trade_date": window_df[self.date_col].values,
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
        }

        # prepare inputs from reward_func
        window_dict = window_df[self.reward_cols].to_dict(orient="list")
        window_dict = {
            k: F.pad(torch.tensor(v, dtype=torch.float32),
                     pad_dim[2:],
                     value=0)
            for k, v in window_dict.items()
        }
        result.update(window_dict)

        return result

    def rolling(self, **kwargs):

        windows = list(self.df.rolling(**kwargs))
        print(sorted(map(len, windows)))
        windows = list(filter(lambda x: len(x) >= self.prompt_len, windows))

        results = [self(win.iloc[0][self.date_col]) for win in windows]
        dataset = Dataset.from_list(results)

        return dataset

In [16]:
pp = PriceProcessor(overall_df, ["EPValue", "yield"],
                    ["openIndex", "closeIndex"])

In [17]:
dataset = pp.rolling(window=2048, step=13)

[1, 14, 27, 40, 53, 66, 79, 92, 105, 118, 131, 144, 157, 170, 183, 196, 209, 222, 235, 248, 261, 274, 287, 300, 313, 326, 339, 352, 365, 378, 391, 404, 417, 430, 443, 456, 469, 482, 495, 508, 521, 534, 547, 560, 573, 586, 599, 612, 625, 638, 651, 664, 677, 690, 703, 716, 729, 742, 755, 768, 781, 794, 807, 820, 833, 846, 859, 872, 885, 898, 911, 924, 937, 950, 963, 976, 989, 1002, 1015, 1028, 1041, 1054, 1067, 1080, 1093, 1106, 1119, 1132, 1145, 1158, 1171, 1184, 1197, 1210, 1223, 1236, 1249, 1262, 1275, 1288, 1301, 1314, 1327, 1340, 1353, 1366, 1379, 1392, 1405, 1418, 1431, 1444, 1457, 1470, 1483, 1496, 1509, 1522, 1535, 1548, 1561, 1574, 1587, 1600, 1613, 1626, 1639, 1652, 1665, 1678, 1691, 1704, 1717, 1730, 1743, 1756, 1769, 1782, 1795, 1808, 1821, 1834, 1847, 1860, 1873, 1886, 1899, 1912, 1925, 1938, 1951, 1964, 1977, 1990, 2003, 2016, 2029, 2042, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048,

In [22]:
dataset.set_format("pt")

In [25]:
dataset["completion_ids"].shape

torch.Size([228, 2043, 2])

In [53]:
data = pp(start_time="2025-03-23", prompt_len=5, max_len=10)
data

{'prompt_ids': tensor([[0.0791, 1.9090],
         [0.0791, 1.8850],
         [0.0795, 1.8790],
         [0.0793, 1.8720],
         [0.0800, 1.8840]]),
 'prompt_mask': tensor([1., 1., 1., 1., 1.]),
 'completion_ids': tensor([[0.0805, 1.8760],
         [0.0801, 1.8740],
         [0.0801, 1.8600],
         [0.0803, 1.7850],
         [0.0000, 0.0000]]),
 'completion_mask': tensor([1., 1., 1., 1., 0.]),
 'openIndex': tensor([3916.1135, 3937.8975, 3930.2073, 3911.3840, 3930.7764, 3905.9814,
         3892.7793, 3884.4182, 3843.4177,    0.0000]),
 'closeIndex': tensor([3934.8486, 3932.2952, 3919.3567, 3932.4116, 3915.1663, 3887.3057,
         3887.6841, 3884.3857, 3861.5034,    0.0000])}

In [33]:
data["action_mask"][:4] = 0

In [34]:
data["action_mask"]

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

In [87]:
pd.to_datetime("2025/1/1 12:01:56").timestamp()

1735732916.0

In [88]:
timedelta()

datetime.datetime(2025, 1, 1, 20, 1, 56)

In [6]:
pe_df.iloc[-1]

tradeDate          2008/5/5
secID           000300.ZICN
ticker                  300
secShortName        ????300
PEValue               27.97
PB                     4.49
ROE                 0.15996
EPValue            0.035753
Name: 4112, dtype: object

In [34]:
lg = bs.login()

login success!


In [11]:
rs = bs.query_history_k_data_plus(
    "sh.600000",
    "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST",
    start_date='2024-07-01',
    end_date='2024-12-31',
    frequency="d",
    adjustflag="3")

In [37]:
rs = bs.query_history_k_data_plus(
    "sh.000300",
    "date,code,open,high,low,close,preclose,volume,amount,pctChg",
    start_date='2009-01-01',
    end_date='2025-04-04',
    frequency="d")

In [38]:
rs.get_data()

Unnamed: 0,date,code,open,high,low,close,preclose,volume,amount,pctChg
0,2009-01-05,sh.000300,1848.3260,1882.9590,1837.8390,1882.9590,1817.7220,4818702000,39217076790.0000,3.588939
1,2009-01-06,sh.000300,1880.6720,1948.4890,1873.0110,1942.7950,1882.9590,7097409700,59262170581.0000,3.177768
2,2009-01-07,sh.000300,1942.6710,1959.2460,1930.8680,1931.1780,1942.7950,5970216700,50143418262.0000,-0.597956
3,2009-01-08,sh.000300,1894.6640,1902.7960,1873.6480,1887.9910,1931.1780,5344649000,47210802398.0000,-2.236304
4,2009-01-09,sh.000300,1886.4880,1923.4230,1886.4880,1918.3650,1887.9910,4740721600,41511211439.0000,1.608801
...,...,...,...,...,...,...,...,...,...,...
3942,2025-03-28,sh.000300,3930.7763,3934.2336,3907.3840,3915.1662,3932.4117,12635621400,213379194289.3000,-0.438548
3943,2025-03-31,sh.000300,3905.9815,3928.4482,3872.4030,3887.3056,3915.1662,16679852300,268480316572.3000,-0.711607
3944,2025-04-01,sh.000300,3892.7794,3905.2622,3882.4045,3887.6841,3887.3056,14038027400,246041674670.2000,0.009737
3945,2025-04-02,sh.000300,3884.4181,3900.4949,3877.1599,3884.3858,3887.6841,11304317500,198267515342.4000,-0.084840


In [47]:
bonds = ak.bond_china_yield(start_date="20240407", end_date="20250407")

In [53]:
bonds.columns

Index(['曲线名称', '日期', '3月', '6月', '1年', '3年', '5年', '7年', '10年', '30年'], dtype='object')

In [58]:
bonds[["曲线名称", '日期', "10年"]][bonds["曲线名称"].str.contains("国债收益率")].tail(10)

Unnamed: 0,曲线名称,日期,10年
722,中债国债收益率曲线,2025-03-24,1.8413
725,中债国债收益率曲线,2025-03-25,1.8183
726,中债国债收益率曲线,2025-03-26,1.7943
731,中债国债收益率曲线,2025-03-27,1.8062
734,中债国债收益率曲线,2025-03-28,1.8126
737,中债国债收益率曲线,2025-03-31,1.8129
738,中债国债收益率曲线,2025-04-01,1.8104
743,中债国债收益率曲线,2025-04-02,1.7887
746,中债国债收益率曲线,2025-04-03,1.718
749,中债国债收益率曲线,2025-04-07,1.6318


In [56]:
bond_df.head()

Unnamed: 0,tradeDate,curveName,maturity,curveType,yield
0,2025-04-03,??????§?????,10,????,1.785
1,2025-04-02,??????§?????,10,????,1.86
2,2025-04-01,??????§?????,10,????,1.874
3,2025-03-31,??????§?????,10,????,1.876
4,2025-03-28,??????§?????,10,????,1.884


In [22]:
stock_zh_index_daily_df = ak.stock_zh_index_daily(symbol="sh000300")

In [23]:
stock_zh_index_daily_df

Unnamed: 0,date,open,high,low,close,volume
0,2002-01-04,1316.455,1316.455,1316.455,1316.455,0
1,2002-01-07,1302.084,1302.084,1302.084,1302.084,0
2,2002-01-08,1292.714,1292.714,1292.714,1292.714,0
3,2002-01-09,1272.645,1272.645,1272.645,1272.645,0
4,2002-01-10,1281.261,1281.261,1281.261,1281.261,0
...,...,...,...,...,...,...
5634,2025-03-28,3930.776,3934.234,3907.384,3915.166,12635621400
5635,2025-03-31,3905.982,3928.448,3872.403,3887.306,16679852300
5636,2025-04-01,3892.779,3905.262,3882.405,3887.684,14038027400
5637,2025-04-02,3884.418,3900.495,3877.160,3884.386,11304317500


In [None]:
bond_china_yield_df = ak.bond_china_yield(date="2020-02-04")

In [None]:
def compute_yield_policy1_old(
    actions: torch.Tensor,  # (N, S)
    open_price: Sequence[float],  # (S,)
    close_price: Sequence[float],  # (S,)
    slippage: float = 0.01,
    stamps: float = 0.0,
    service_fee: float = 1.3e-4,
    assets: float = 2e5,
    **kwargs,
):

    if actions.dim() == 1:
        actions = actions.unsqueeze(0)

    num_generations, span_len = actions.shape
    zero = torch.zeros((num_generations, 1), device=actions.device)
    diff = torch.diff(actions.eq(1).int(), prepend=zero, append=zero)
    yields = torch.zeros(num_generations).to(actions.device)

    for i in range(num_generations):

        total_assets = assets

        # calculate the holding periods
        start = torch.where(diff[i] == 1)[0]
        end = torch.where(diff[i] == -1)[0]

        for span in list(zip(start.tolist(), end.tolist())):
            start_idx, end_idx = span

            bid_rate = (open_price[start_idx] + slippage) * (1 + service_fee)
            shares = total_assets // bid_rate
            total_assets = total_assets % bid_rate

            if end_idx < span_len:
                ask_rate = (open_price[end_idx] -
                            slippage) * (1 - service_fee - stamps)
            else:
                ask_rate = close_price[-1]

            total_assets += shares * ask_rate

        yields[i] = (total_assets - assets) / assets

    return yields