In [1]:
import os

# Project root used as the working directory for the rest of the notebook.
# Override it per machine with the INDEX_FUTURE_PROJECT_DIR environment
# variable; the original author's local path is kept as the fallback so
# existing setups behave exactly as before.
PROJECT_DIR = os.environ.get(
    'INDEX_FUTURE_PROJECT_DIR',
    'd:/future/index_future_prediction/Index_Future_Prediction',
)
os.chdir(PROJECT_DIR)

In [2]:
import os

import tushare as ts

# The Tushare API token is a credential and must not live in the notebook.
# Read it from the TUSHARE_TOKEN environment variable instead. Any token that
# was previously committed in this cell should be treated as leaked and revoked.
_TUSHARE_TOKEN = os.environ.get('TUSHARE_TOKEN')
if not _TUSHARE_TOKEN:
    raise RuntimeError(
        'Set the TUSHARE_TOKEN environment variable to your Tushare Pro API token.'
    )

# Shared Tushare Pro client used by the data-loading code below.
pro = ts.pro_api(_TUSHARE_TOKEN)

In [14]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import lr_scheduler, Adam, AdamW
from scipy.stats import norm, t

In [25]:
class GetOHLCV():
    """Fetch daily futures bars and derive log-return features and soft labels.

    The class keeps no state of its own. It relies on the notebook-level
    Tushare client ``pro`` being defined before :meth:`get_data` is called.
    """

    def __init__(self):
        pass


    def get_data(self, assets_code, pred_len, threshold_ratio, lookback = 250):
        """Download daily bars for one contract and build features and labels.

        Args:
            assets_code: ``ts_code`` passed to ``pro.fut_daily``, e.g. ``'M.DCE'``.
            pred_len: Number of following rows (trading days) whose log returns
                are summed into ``label_return``.
            threshold_ratio: Quantile level passed to the rolling quantiles; the
                lower bound uses ``threshold_ratio`` and the upper bound uses
                ``1 - threshold_ratio``.
            lookback: Rolling window, in rows, for the trailing statistics.
                Defaults to 250, the "one year" window used originally.

        Returns:
            The downloaded frame sorted by ascending ``trade_date`` with the
            added columns ``log_open``, ``log_high``, ``log_low``, ``log_close``,
            ``log_amount``, ``label_return``, ``ma_amount``, ``ma_return_std``,
            ``label_std``, ``upper_bond``, ``lower_bond``, ``threshold``,
            ``down_prob``, ``middle_prob`` and ``up_prob``. Rows containing any
            NaN are dropped; the original row index is not reset.

        Raises:
            ValueError: If ``pred_len`` is smaller than 1 or ``lookback`` is
                smaller than 2.
        """
        if pred_len < 1:
            raise ValueError(f'pred_len must be a positive integer, got {pred_len!r}')
        if lookback < 2:
            raise ValueError(f'lookback must be at least 2, got {lookback!r}')

        # The history is fetched in two requests (presumably because one call
        # cannot return the full range -- TODO confirm the API row limit).
        data_1 = pro.fut_daily(ts_code = assets_code, start_date = '20110101', end_date = '20180101')
        data_2 = pro.fut_daily(ts_code = assets_code, start_date = '20180101')

        data = pd.concat([data_1, data_2], ignore_index = True)
        # Both requests name 20180101 as a bound, so a bar dated on the boundary
        # can be returned twice. A duplicated row would distort every shift and
        # rolling window below, so keep a single copy per contract and date.
        data.drop_duplicates(subset = ['ts_code', 'trade_date'], inplace = True)

        # oi_chg is overwritten with a constant, presumably so that missing
        # values in it do not make dropna() discard the row -- TODO confirm.
        data['oi_chg'] = 1
        data.dropna(inplace = True)
        data.sort_values(by = 'trade_date', inplace = True)

        # Prices expressed as log returns versus the previous close, in percent
        # units (without the percent sign).
        for price in ('open', 'high', 'low', 'close'):
            data[f'log_{price}'] = np.log(data[price] / data['pre_close']) * 100
        # Day-over-day log change of the traded amount (scaled by 10, not 100).
        data['log_amount'] = np.log(data['amount'] / data['amount'].shift(1)) * 10

        # Log returns are additive, so the return over the next pred_len rows is
        # the rolling sum shifted back by pred_len rows.
        data['label_return'] = data['log_close'].rolling(window = pred_len).sum().shift(-pred_len)

        # Trailing statistics over the last `lookback` rows.
        data['ma_amount'] = data['amount'].rolling(window = lookback).mean()
        data['ma_return_std'] = data['label_return'].rolling(window = lookback).std()

        # Label standard deviation: the trailing return std scaled by the ratio
        # of the mean amount over the next pred_len rows to the trailing mean.
        # NOTE: label_return and the forward amount are shifted by -pred_len, so
        # this column and everything derived from it below contain future
        # information and must only be used as labels, never as model inputs.
        forward_amount = data['amount'].rolling(window = pred_len).mean().shift(-pred_len)
        data['label_std'] = forward_amount / data['ma_amount'] * data['ma_return_std']

        # Upper and lower trailing quantiles of label_return.
        data['upper_bond'] = data['label_return'].rolling(window = lookback).quantile(1 - threshold_ratio)
        data['lower_bond'] = data['label_return'].rolling(window = lookback).quantile(threshold_ratio)
        # Symmetric class boundary: the mean magnitude of the two quantiles.
        data['threshold'] = (abs(data['upper_bond']) + abs(data['lower_bond']))/2

        # Soft three-class labels: the mass of N(label_return, label_std) below
        # -threshold, between the thresholds, and above +threshold. Computed on
        # whole columns at once; rows with NaN inputs yield NaN and are dropped
        # below, exactly as with a row-by-row evaluation.
        threshold = data['threshold'].to_numpy()
        center = data['label_return'].to_numpy()
        spread = data['label_std'].to_numpy()
        cdf_lower = norm.cdf(-threshold, loc = center, scale = spread)
        cdf_upper = norm.cdf(threshold, loc = center, scale = spread)

        data['down_prob'] = cdf_lower
        data['middle_prob'] = cdf_upper - cdf_lower
        data['up_prob'] = 1 - cdf_upper

        # Remove warm-up rows of the rolling windows and the last pred_len rows,
        # which have no forward label.
        data.dropna(inplace = True)

        return data


In [None]:
if __name__ == '__main__':
    # GetOHLCV() takes no constructor arguments; the contract code, prediction
    # horizon and quantile level are parameters of get_data().
    GetOHLCV().get_data('M.DCE', 5, 0.33)

In [None]:
# Build the dataset for contract M.DCE with a 5-row horizon and a 0.33 quantile
# level; the resulting frame is displayed as the cell output.
GetOHLCV().get_data(assets_code = 'M.DCE', pred_len = 5, threshold_ratio = 0.33)

Unnamed: 0,ts_code,trade_date,pre_close,pre_settle,open,high,low,close,settle,change1,...,label_return,ma_amount,ma_return_std,label_std,upper_bond,lower_bond,threshold,down_prob,middle_prob,up_prob
1452,M.DCE,20120111,2983.0,2981.0,2975.0,2982.0,2965.0,2976.0,2973.0,-5.0,...,-2.828609,1.069966e+06,2.108374,1.417075,0.629300,-1.152392,0.890846,0.914256,0.081408,0.004336
1451,M.DCE,20120112,2976.0,2973.0,2942.0,2948.0,2935.0,2947.0,2941.0,-26.0,...,-1.194757,1.066340e+06,2.108898,1.367034,0.629300,-1.166343,0.897822,0.585978,0.351105,0.062916
1450,M.DCE,20120113,2947.0,2941.0,2921.0,2929.0,2910.0,2913.0,2918.0,-28.0,...,1.228265,1.064587e+06,2.110528,1.242255,0.636482,-1.166343,0.901412,0.043231,0.353001,0.603768
1449,M.DCE,20120116,2913.0,2918.0,2882.0,2886.0,2870.0,2873.0,2878.0,-45.0,...,2.848022,1.062398e+06,2.118528,1.187591,0.636482,-1.166343,0.901412,0.000797,0.049797,0.949406
1448,M.DCE,20120117,2873.0,2878.0,2890.0,2909.0,2883.0,2909.0,2894.0,31.0,...,1.162009,1.059401e+06,2.103195,1.086476,0.636482,-1.166343,0.901412,0.028770,0.376451,0.594779
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1711,M.DCE,20250915,3079.0,3087.0,3080.0,3083.0,3032.0,3042.0,3051.0,-45.0,...,-0.263331,4.023548e+06,2.691513,2.060651,0.897149,-1.082975,0.990062,0.362168,0.366322,0.271510
1710,M.DCE,20250916,3042.0,3051.0,3046.0,3058.0,3030.0,3041.0,3046.0,-10.0,...,-3.786681,4.012557e+06,2.701284,2.622218,0.896266,-1.107239,1.001753,0.855894,0.110189,0.033918
1709,M.DCE,20250917,3041.0,3046.0,3042.0,3051.0,2998.0,3002.0,3019.0,-44.0,...,-2.427631,4.004861e+06,2.704909,2.550656,0.896266,-1.123008,1.009637,0.710872,0.200235,0.088893
1708,M.DCE,20250918,3002.0,3019.0,3008.0,3016.0,2990.0,2993.0,3001.0,-26.0,...,-0.872489,3.993290e+06,2.703824,2.735224,0.896266,-1.107239,1.001753,0.481153,0.272245,0.246601
