In [2]:
from model import Kronos, KronosTokenizer, KronosPredictor

# Load from Hugging Face Hub
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-base")

# Initialize the predictor
predictor = KronosPredictor(model, tokenizer, device="mps", max_context=512)


In [3]:
from matplotlib import pyplot as plt


def plot_prediction(kline_df, pred_df):
    pred_df.index = kline_df.index[-pred_df.shape[0] :]
    sr_close = kline_df["close"]
    sr_pred_close = pred_df["close"]
    sr_close.name = "Ground Truth"
    sr_pred_close.name = "Prediction"

    sr_volume = kline_df["volume"]
    sr_pred_volume = pred_df["volume"]
    sr_volume.name = "Ground Truth"
    sr_pred_volume.name = "Prediction"

    close_df = pd.concat([sr_close, sr_pred_close], axis=1)
    volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    ax1.plot(
        close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
    )
    ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
    ax1.set_ylabel("Close Price", fontsize=14)
    ax1.legend(loc="lower left", fontsize=12)
    ax1.grid(True)

    ax2.plot(
        volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5
    )
    ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
    ax2.set_ylabel("Volume", fontsize=14)
    ax2.legend(loc="upper left", fontsize=12)
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

In [6]:
import pandas as pd
stock_data = pd.read_csv("./examples/data/XSHG_5min_600977.csv")

In [7]:
stock_data.tail()

Unnamed: 0,timestamps,open,high,low,close,volume,amount
2495,2024-08-29 11:10:00,9.87,9.87,9.85,9.85,549.0,541430.0
2496,2024-08-29 11:15:00,9.86,9.87,9.85,9.87,751.0,740963.0
2497,2024-08-29 11:20:00,9.86,9.87,9.85,9.87,437.0,430959.0
2498,2024-08-29 11:25:00,9.86,9.89,9.86,9.86,625.0,617074.0
2499,2024-08-29 11:30:00,9.87,9.89,9.87,9.88,349.0,344735.0


In [8]:
selected_data = stock_data[["open", "high", "low", "close", "volume","amount"]].reset_index(drop=True)

In [9]:
len(selected_data)

2500

In [12]:
x_timestamp = stock_data["timestamps"].reset_index(drop=True)
x_timestamp = x_timestamp[len(stock_data) - 512 :]
x_timestamp = pd.to_datetime(x_timestamp)

In [13]:
len(x_timestamp)

512

In [14]:
dates = pd.bdate_range(start="2025-09-23", end="2025-09-24")


# 将 DatetimeIndex 转换为 Series

y_timestamp_series = pd.Series(dates)

# 查看转换后的 Series
print(y_timestamp_series)

0   2025-09-23
1   2025-09-24
dtype: datetime64[ns]


In [15]:
# Generate predictions
pred_df = predictor.predict(
    df=selected_data,
    x_timestamp=x_timestamp,
    y_timestamp=y_timestamp_series,
    pred_len=len(y_timestamp_series),
    T=1.0,  # Temperature for sampling
    top_p=0.9,  # Nucleus sampling probability
    sample_count=1,  # Number of forecast paths to generate and average
)

print("Forecasted Data Head:")
print(pred_df.head())

100%|██████████| 2/2 [00:04<00:00,  2.01s/it]


Forecasted Data Head:
                open      high       low     close      volume       amount
2025-09-23  9.816091  9.847658  9.833866  9.863901  678.991760  717339.0625
2025-09-24  9.858664  9.849227  9.855570  9.845477  523.683533  560401.7500


In [17]:
pred_df["date"] = pred_df.index

In [20]:
pred_df.to_csv("./a.csv")

In [74]:
import baostock as bs
import pandas as pd

#### 登陆系统 ####
lg = bs.login()
# 显示登陆返回信息
print("login respond error_code:" + lg.error_code)
print("login respond  error_msg:" + lg.error_msg)

#### 获取交易日信息 ####
rs = bs.query_trade_dates(start_date="2025-09-23", end_date="2026-09-30")
print("query_trade_dates respond error_code:" + rs.error_code)
print("query_trade_dates respond  error_msg:" + rs.error_msg)

#### 打印结果集 ####
data_list = []
while (rs.error_code == "0") & rs.next():
    # 获取一条记录，将记录合并在一起
    data_list.append(rs.get_row_data())
result = pd.DataFrame(data_list, columns=rs.fields)

#### 结果集输出到csv文件 ####

print(result)

#### 登出系统 ####
bs.logout()

login success!
login respond error_code:0
login respond  error_msg:success
query_trade_dates respond error_code:0
query_trade_dates respond  error_msg:success
   calendar_date is_trading_day
0     2025-09-23              1
1     2025-09-24              1
2     2025-09-25              1
3     2025-09-26              1
4     2025-09-27              0
..           ...            ...
95    2025-12-27              0
96    2025-12-28              0
97    2025-12-29              1
98    2025-12-30              1
99    2025-12-31              1

[100 rows x 2 columns]
logout success!


<baostock.data.resultset.ResultData at 0x106dd2150>

In [3]:
import baostock as bs
import pandas as pd

#### 登陆系统 ####
lg = bs.login()
# 显示登陆返回信息
print("login respond error_code:" + lg.error_code)
print("login respond  error_msg:" + lg.error_msg)

#### 获取某日所有证券信息 ####
rs = bs.query_all_stock(day="2024-10-25")
print("query_all_stock respond error_code:" + rs.error_code)
print("query_all_stock respond  error_msg:" + rs.error_msg)

#### 打印结果集 ####
data_list = []
while (rs.error_code == "0") & rs.next():
    # 获取一条记录，将记录合并在一起
    data_list.append(rs.get_row_data())
result = pd.DataFrame(data_list, columns=rs.fields)

#### 结果集输出到csv文件 ####
result.to_csv("D:\\all_stock.csv", encoding="utf-8", index=False)
print(result)

#### 登出系统 ####
bs.logout()

login success!
login respond error_code:0
login respond  error_msg:success
query_all_stock respond error_code:0
query_all_stock respond  error_msg:success
           code tradeStatus   code_name
0     sh.000001           1      上证综合指数
1     sh.000002           1      上证A股指数
2     sh.000003           1      上证B股指数
3     sh.000004           1     上证工业类指数
4     sh.000005           1     上证商业类指数
...         ...         ...         ...
5641  sz.399994           1  中证信息安全主题指数
5642  sz.399995           1    中证基建工程指数
5643  sz.399996           1    中证智能家居指数
5644  sz.399997           1      中证白酒指数
5645  sz.399998           1      中证煤炭指数

[5646 rows x 3 columns]
logout success!


<baostock.data.resultset.ResultData at 0x11e15fef0>

In [None]:
for idx,data in result.iterrows():
    print(data.)

sh.000001
sh.000002
sh.000003
sh.000004
sh.000005
sh.000006
sh.000007
sh.000008
sh.000009
sh.000010
sh.000011
sh.000012
sh.000013
sh.000015
sh.000016
sh.000017
sh.000018
sh.000019
sh.000020
sh.000021
sh.000022
sh.000025
sh.000026
sh.000027
sh.000028
sh.000029
sh.000030
sh.000031
sh.000032
sh.000033
sh.000034
sh.000035
sh.000036
sh.000037
sh.000038
sh.000039
sh.000040
sh.000041
sh.000042
sh.000043
sh.000044
sh.000045
sh.000046
sh.000047
sh.000048
sh.000049
sh.000050
sh.000051
sh.000052
sh.000053
sh.000054
sh.000056
sh.000057
sh.000058
sh.000059
sh.000060
sh.000061
sh.000062
sh.000063
sh.000064
sh.000065
sh.000066
sh.000067
sh.000068
sh.000069
sh.000070
sh.000071
sh.000072
sh.000073
sh.000074
sh.000075
sh.000076
sh.000077
sh.000078
sh.000079
sh.000090
sh.000091
sh.000092
sh.000093
sh.000094
sh.000095
sh.000096
sh.000097
sh.000098
sh.000099
sh.000100
sh.000101
sh.000102
sh.000103
sh.000104
sh.000105
sh.000106
sh.000107
sh.000108
sh.000109
sh.000110
sh.000111
sh.000112
sh.000113
sh.000114
