In [None]:
import pandas as pd
from datetime import datetime, timedelta
import os

def _book_top(book, depth):
    """Return (bids, offers): the `depth` best resting orders on each side of `book`.

    `book` maps appl_seq_num -> [side, price, order_qty]. Bids (side 1) are sorted
    by descending price, offers (side 2) by ascending price; each result is a
    DataFrame with columns appl_seq_num, side, price, order_qty.
    """
    frame = pd.DataFrame(
        [(seq, side, price, qty) for seq, (side, price, qty) in book.items()],
        columns=['appl_seq_num', 'side', 'price', 'order_qty'],
    )
    bids = frame[frame['side'] == 1].sort_values(by='price', ascending=False).head(depth)
    offers = frame[frame['side'] == 2].sort_values(by='price', ascending=True).head(depth)
    return bids, offers


def _best_price(book, side, highest):
    """Return the highest (or lowest) non-NaN price resting on `side`; NaN if there is none.

    NaN prices are skipped explicitly because Python's built-in max/min are
    order-dependent with NaN (max(nan, x) is nan, max(x, nan) is x).
    """
    prices = [price for s, price, _ in book.values() if s == side and not pd.isna(price)]
    if not prices:
        return float('nan')
    return max(prices) if highest else min(prices)


def create_order_book(df, timestamps, depth=1):
    """Replay Level-3 order events and snapshot the top of the book at each timestamp.

    Args:
        df: order events, assumed already sorted by transact_time, with columns
            transact_time, order_type, appl_seq_num, side (1 = bid, 2 = offer),
            price, order_qty, bid_appl_seq_num and offer_appl_seq_num.
        timestamps: snapshot instants (anything pd.to_datetime accepts); sorted here.
        depth: number of best orders kept per side in each snapshot (default 1).

    Returns:
        dict mapping each snapshot pd.Timestamp to a (bids, offers) pair of
        DataFrames with columns appl_seq_num, side, price, order_qty.

    A snapshot at time t reflects every event with transact_time <= t.
    Order types: '2' adds a limit order; '1' adds an order priced at the best
    opposite-side price; 'U' adds an order priced at the best same-side price;
    '4' cancels the referenced order(s); 'F' is a trade that reduces the
    referenced bid/offer quantities and drops orders that reach zero.
    """
    timestamps = sorted(pd.to_datetime(timestamps))
    order_books = {ts: None for ts in timestamps}
    book = {}  # appl_seq_num -> [side, price, order_qty] for every resting order
    ts_index = 0

    for row in df.itertuples(index=False):
        # Finalize every snapshot that lies strictly before this event.
        while ts_index < len(timestamps) and row.transact_time > timestamps[ts_index]:
            order_books[timestamps[ts_index]] = _book_top(book, depth)
            ts_index += 1
        if ts_index == len(timestamps):
            break  # all snapshots taken; later events cannot change them

        order_type = row.order_type
        if order_type == '2':  # limit order
            book[row.appl_seq_num] = [row.side, row.price, row.order_qty]
        elif order_type == '1':  # priced at the best opposite-side quote
            price = (_best_price(book, 1, highest=True) if row.side == 2
                     else _best_price(book, 2, highest=False))
            book[row.appl_seq_num] = [row.side, price, row.order_qty]
        elif order_type == 'U':  # priced at the best same-side quote
            price = (_best_price(book, 1, highest=True) if row.side == 1
                     else _best_price(book, 2, highest=False))
            book[row.appl_seq_num] = [row.side, price, row.order_qty]
        elif order_type == '4':  # cancel
            for seq in (row.bid_appl_seq_num, row.offer_appl_seq_num):
                if seq != 0:
                    book.pop(seq, None)
        elif order_type == 'F':  # trade
            for seq in (row.bid_appl_seq_num, row.offer_appl_seq_num):
                entry = book.get(seq)
                if entry is None:
                    continue
                entry[2] -= row.order_qty
                if not entry[2] > 0:  # also drops NaN quantities
                    del book[seq]

    # Snapshots after the last event all see the final state of the book.
    for ts in timestamps[ts_index:]:
        order_books[ts] = _book_top(book, depth)

    return order_books


# For every day in the date range that has a 000069 data file, rebuild the order
# book and print the best bid/offer every 3 seconds during the trading sessions.
# The printed output is what the log-parsing cell reads back.
directory = "/workspaces/quant-project/data/sz_level3/000069/"  # Level-3 data for 000069
stock_code = '000069'


def _session_timestamps(day, start, end, step=pd.Timedelta(seconds=3)):
    """Return 'YYYY-MM-DD HH:MM:SS' strings from `start` to `end` (both inclusive) on `day`."""
    day_str = day.strftime("%Y-%m-%d")
    stamps = pd.date_range(start=f"{day_str} {start}", end=f"{day_str} {end}", freq=step)
    return list(stamps.strftime("%Y-%m-%d %H:%M:%S"))


# Calendar days; days without a data file are skipped below.
date_range = pd.date_range(start="2020-01-23", end="2020-07-07")

for date in date_range:
    date_str = date.strftime("%Y%m%d")  # file names use YYYYMMDD
    file_path = os.path.join(directory, "000069_" + date_str + ".csv.gz")
    if not os.path.isfile(file_path):
        continue  # presumably a non-trading day

    # Morning (09:30-11:30) and afternoon (13:00-15:00) sessions, 3 s apart.
    timestamps = (_session_timestamps(date, "09:30:00", "11:30:00")
                  + _session_timestamps(date, "13:00:00", "15:00:00"))

    df = pd.read_csv(file_path, compression='gzip')
    df['transact_time'] = pd.to_datetime(df['transact_time'], format="%Y%m%d%H%M%S%f")
    df['price'] = df['price'] / 10000  # prices are stored scaled by 10000
    df['order_qty'] = df['order_qty'] / 100  # quantities are stored scaled by 100
    df['stock_code'] = stock_code

    order_books = create_order_book(df, timestamps)

    for timestamp, (bid_side, offer_side) in order_books.items():
        print("Order book at", timestamp)
        print("Bid side:")
        print(bid_side)
        print("Offer side:")
        print(offer_side)



In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import re

# Parse the order-book snapshots printed by the first cell (captured in the log
# file) into best bid/offer price and quantity per timestamp, save them, and plot.
NUMBER_PATTERN = re.compile(r"\d+\.\d+|\d+")
TIME_PATTERN = re.compile(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}")


def _top_of_side(lines, header_index):
    """Return (price, qty) from the first row printed two lines below a side header, or None.

    The first data row reads 'index appl_seq_num side price order_qty', so price
    and quantity are the 4th and 5th numbers. An empty side prints
    'Empty DataFrame' / 'Columns: [...]', which yields no numbers -> None.
    """
    row_index = header_index + 2
    if row_index >= len(lines):
        return None
    numbers = NUMBER_PATTERN.findall(lines[row_index])
    if not numbers:
        return None
    return float(numbers[3]), float(numbers[4])


with open('/workspaces/quant-project/log', 'r') as file:
    lines = file.readlines()

# One [time, bid, offer] entry per "Order book at" header; bid/offer are
# (price, qty) tuples, or None when that side was empty.
snapshots = []
for i, line in enumerate(lines):
    if "Order book at" in line:
        snapshots.append([TIME_PATTERN.findall(line)[0], None, None])
    elif "Bid side:" in line and snapshots:
        snapshots[-1][1] = _top_of_side(lines, i)
    elif "Offer side:" in line and snapshots:
        snapshots[-1][2] = _top_of_side(lines, i)

# Keep only snapshots where both sides had at least one order.
complete = [(stamp, bid, offer) for stamp, bid, offer in snapshots
            if bid is not None and offer is not None]

df = pd.DataFrame({
    'Timestamp': [pd.to_datetime(stamp) for stamp, _, _ in complete],
    'Bid Price': [bid[0] for _, bid, _ in complete],
    'Offer Price': [offer[0] for _, _, offer in complete],
    'Bid Quantity': [bid[1] for _, bid, _ in complete],
    'Offer Quantity': [offer[1] for _, _, offer in complete],
})

# Name the row-number column 'Num': the training cell reads this file back and
# drops a column called 'Num'.
df.to_csv('00069.csv', index_label='Num')

# Set Timestamp as index
df = df.set_index('Timestamp')

# Plot the data
plt.figure(figsize=[15, 10])
plt.grid(True)
plt.plot(df['Bid Price'], label='Bid Price', linewidth=2, markersize=12)
plt.plot(df['Offer Price'], label='Offer Price', linewidth=2, markersize=12)
plt.xlabel('Timestamp')
plt.ylabel('Price')
plt.title('Bid and Offer Prices Over Time', fontsize=20)
plt.legend(loc=2)
plt.savefig('plot.png')
plt.show()


In [None]:
import pandas as pd

def _book_top(book, depth):
    """Return (bids, offers): the `depth` best resting orders on each side of `book`.

    `book` maps appl_seq_num -> [side, price, order_qty]. Bids (side 1) are sorted
    by descending price, offers (side 2) by ascending price; each result is a
    DataFrame with columns appl_seq_num, side, price, order_qty.
    """
    frame = pd.DataFrame(
        [(seq, side, price, qty) for seq, (side, price, qty) in book.items()],
        columns=['appl_seq_num', 'side', 'price', 'order_qty'],
    )
    bids = frame[frame['side'] == 1].sort_values(by='price', ascending=False).head(depth)
    offers = frame[frame['side'] == 2].sort_values(by='price', ascending=True).head(depth)
    return bids, offers


def _best_price(book, side, highest):
    """Return the highest (or lowest) non-NaN price resting on `side`; NaN if there is none.

    NaN prices are skipped explicitly because Python's built-in max/min are
    order-dependent with NaN (max(nan, x) is nan, max(x, nan) is x).
    """
    prices = [price for s, price, _ in book.values() if s == side and not pd.isna(price)]
    if not prices:
        return float('nan')
    return max(prices) if highest else min(prices)


def create_order_book(df, timestamps, depth=10):
    """Replay Level-3 order events and snapshot the `depth` best orders per side.

    Args:
        df: order events, assumed already sorted by transact_time, with columns
            transact_time, order_type, appl_seq_num, side (1 = bid, 2 = offer),
            price, order_qty, bid_appl_seq_num and offer_appl_seq_num.
        timestamps: snapshot instants (anything pd.to_datetime accepts); sorted here.
        depth: number of best orders kept per side in each snapshot (default 10).

    Returns:
        dict mapping each snapshot pd.Timestamp to a (bids, offers) pair of
        DataFrames with columns appl_seq_num, side, price, order_qty.

    A snapshot at time t reflects every event with transact_time <= t.
    Market ('1') and own-best ('U') orders are priced from the *current* book,
    so they work before the first snapshot and when a side is empty (NaN price).
    """
    timestamps = sorted(pd.to_datetime(timestamps))
    order_books = {ts: None for ts in timestamps}
    book = {}  # appl_seq_num -> [side, price, order_qty] for every resting order
    ts_index = 0

    for row in df.itertuples(index=False):
        # Finalize every snapshot that lies strictly before this event.
        while ts_index < len(timestamps) and row.transact_time > timestamps[ts_index]:
            order_books[timestamps[ts_index]] = _book_top(book, depth)
            ts_index += 1
        if ts_index == len(timestamps):
            break  # all snapshots taken; later events cannot change them

        order_type = row.order_type
        if order_type == '2':  # limit order
            book[row.appl_seq_num] = [row.side, row.price, row.order_qty]
        elif order_type == '1':  # priced at the best opposite-side quote
            price = (_best_price(book, 1, highest=True) if row.side == 2
                     else _best_price(book, 2, highest=False))
            book[row.appl_seq_num] = [row.side, price, row.order_qty]
        elif order_type == 'U':  # priced at the best same-side quote
            price = (_best_price(book, 1, highest=True) if row.side == 1
                     else _best_price(book, 2, highest=False))
            book[row.appl_seq_num] = [row.side, price, row.order_qty]
        elif order_type == '4':  # cancel
            for seq in (row.bid_appl_seq_num, row.offer_appl_seq_num):
                if seq != 0:
                    book.pop(seq, None)
        elif order_type == 'F':  # trade
            for seq in (row.bid_appl_seq_num, row.offer_appl_seq_num):
                entry = book.get(seq)
                if entry is None:
                    continue
                entry[2] -= row.order_qty
                if not entry[2] > 0:  # also drops NaN quantities
                    del book[seq]

    # Snapshots after the last event all see the final state of the book.
    for ts in timestamps[ts_index:]:
        order_books[ts] = _book_top(book, depth)

    return order_books


# Print the reconstructed order book at three instants on 2020-01-10 for six stocks.
timestamps = ['2020-01-10 09:30:00', '2020-01-10 10:30:00', '2020-01-10 13:30:00']

for stock_code in ['000069', '000566', '000876', '002304', '002841', '002918']:
    source = f'/workspaces/quant-project/data/sz_level3/{stock_code}/{stock_code}_20200110.csv.gz'
    # Parse times and undo the storage scaling of price (x10000) and quantity (x100).
    df = pd.read_csv(source, compression='gzip').assign(
        transact_time=lambda d: pd.to_datetime(d['transact_time'], format="%Y%m%d%H%M%S%f"),
        price=lambda d: d['price'] / 10000,
        order_qty=lambda d: d['order_qty'] / 100,
        stock_code=stock_code,
    )

    order_books = create_order_book(df, timestamps)

    for timestamp, (bid_side, offer_side) in order_books.items():
        print('Stock Code:', stock_code)
        print("Order book at", timestamp)
        print("Bid side:")
        print(bid_side)
        print("Offer side:")
        print(offer_side)

In [None]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import mplfinance as mpf
import matplotlib as mpl# 用于设置曲线参数
from cycler import cycler# 用于定制线条颜色


# In[2]:


def data_wash(dataset, keepTime=False):
    """Clean missing values: drop incomplete rows, or forward-fill from the previous row.

    Args:
        dataset: DataFrame to clean (not modified).
        keepTime: if True, keep every row and fill each NaN with the value from
            the previous row of the same column; if False, drop every row that
            contains a NaN.

    Returns:
        The cleaned DataFrame. fillna/ffill/dropna return new objects, so the
        result must be returned rather than relying on in-place mutation.
    """
    if keepTime:
        # axis=0: fill down each column from the previous *row*.
        return dataset.ffill()
    return dataset.dropna()


# In[3]:


def import_csv(stock_code):
    """Load ``<stock_code>.csv``, clean it, and relabel its columns with OHLC names.

    The relabelling is purely nominal (so plotting code that expects OHLC names
    accepts the frame): Timestamp -> Date, Bid Price -> Open, Offer Price -> High,
    Bid Quantity -> Low, Offer Quantity -> Close. 'Date' becomes the index and is
    also kept as a regular column.
    """
    column_map = {
        'Timestamp': 'Date',
        'Bid Price': 'Open',
        'Offer Price': 'High',
        'Bid Quantity': 'Low',
        'Offer Quantity': 'Close',
    }
    frame = data_wash(pd.read_csv(stock_code + '.csv'), keepTime=False)
    frame.rename(columns=column_map, inplace=True)
    frame.set_index(frame['Date'], inplace=True)
    return frame


# In[361]:


def draw_Kline(df, period, symbol):
    """Draw a candlestick chart of `df` with mplfinance, display it, and save it as JPEG.

    Args:
        df: DataFrame with Open/High/Low/Close columns and (per mplfinance's
            requirements) a DatetimeIndex; a 'Volume' column is optional.
        period: only used in the output file name.
        symbol: stock code shown in the title and the file name.

    Side effects: sets the global matplotlib rcParams for the line colour cycle
    and line width, and writes 'A_stock-<symbol> <period>_candle_line.jpg'.
    """
    kwargs = dict(
        type='candle',        # chart type: candle, renko, ohlc, line, ...
        mav=(7, 30, 60),      # moving averages over 7, 30 and 60 bars
        # Only request the volume panel when a Volume column exists: mplfinance
        # raises otherwise, and import_csv produces no such column.
        volume='Volume' in df.columns,
        title='\nA_stock %s candle_line' % (symbol),
        ylabel='OHLC Candles',
        ylabel_lower='Shares\nTraded Volume',
        figratio=(15, 10),    # width:height aspect ratio
        figscale=2)           # larger scale -> higher image quality

    # Candle colours follow the Chinese convention: red when close >= open, green
    # otherwise; edges, wicks and volume bars inherit those colours.
    mc = mpf.make_marketcolors(
        up='red',
        down='green',
        edge='i',
        wick='i',
        volume='in',
        inherit=True)

    # Dash-dot grid on both axes, price axis on the left.
    s = mpf.make_mpf_style(
        gridaxis='both',
        gridstyle='-.',
        y_on_right=False,
        marketcolors=mc)

    # Colours for the moving-average lines: dark tones that contrast with red/green.
    mpl.rcParams['axes.prop_cycle'] = cycler(
        color=['dodgerblue', 'deeppink',
               'navy', 'teal', 'maroon', 'darkorange',
               'indigo'])

    # Line width for the moving averages.
    mpl.rcParams['lines.linewidth'] = .5

    # First call displays the chart; the second (with savefig) writes the image file.
    mpf.plot(df,
             **kwargs,
             style=s,
             show_nontrading=False,)
    mpf.plot(df,
             **kwargs,
             style=s,
             show_nontrading=False,
             savefig='A_stock-%s %s_candle_line'
             % (symbol, period) + '.jpg')
    plt.show()


# In[163]:


import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset,DataLoader
import math
import numpy as np
import os


# In[320]:


# Load the snapshot CSV, split it chronologically into train/test files, and save them.
TRAIN_WEIGHT = 0.9       # fraction of rows used for training
SEQ_LEN = 49             # length of each input window (rows)
N_Pre = 10               # only referenced by commented-out label code in Stock_Data
LEARNING_RATE = 0.00001
BATCH_SIZE = 4
EPOCH = 2

symbol = '00069'
data = import_csv(symbol)
# df_draw=data[-period:]
# draw_Kline(df_draw,period,symbol)
# Drop the CSV's row-number column (presumably 'Num'); 'Date' stays as column 0,
# which Stock_Data skips because it reads CSV columns 1-4.
data.drop(['Num'], axis=1, inplace=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_size = int(TRAIN_WEIGHT * (data.shape[0]))
train_path = "stock_daily/stock_train.csv"
test_path = "stock_daily/stock_test.csv"
# Training rows end at train_size so no training label is also a test label
# (the old `[:train_size+SEQ_LEN]` leaked SEQ_LEN test targets into training).
# The test file starts SEQ_LEN rows earlier so that its first window
# (rows train_size-SEQ_LEN .. train_size-1) predicts row train_size.
Train_data = data[:train_size]
Test_data = data[train_size - SEQ_LEN:]
os.makedirs(os.path.dirname(train_path), exist_ok=True)
Train_data.to_csv(train_path, sep=',', index=False, header=False)
Test_data.to_csv(test_path, sep=',', index=False, header=False)


# In[321]:


# Per-feature normalisation statistics; Stock_Data(train=True) fills them in.
mean_list, std_list = [], []


# In[358]:


# Sliding-window dataset over the split CSV files.
class Stock_Data(Dataset):
    """Sliding-window dataset built from ``stock_daily/stock_train.csv`` or ``stock_test.csv``.

    Item ``i`` is ``(window, label)``:
      * ``window``: float32 tensor of shape ``(SEQ_LEN, 4)`` holding rows
        ``i .. i+SEQ_LEN-1`` of CSV columns 1-4, z-scored per column;
      * ``label``: float32 tensor of shape ``(1,)`` holding the z-scored value of
        feature 0 (CSV column 1) in row ``i+SEQ_LEN``, i.e. the next row.

    The training split computes each column's mean/std and stores them in the
    module-level ``mean_list`` / ``std_list``; the test split normalises with
    those same statistics, so the training dataset must be built first.

    Args:
        train: read the training file if True, otherwise the test file.
        transform: accepted for API compatibility; unused.

    Raises:
        RuntimeError: test split requested before any training statistics exist.
        ValueError: the file has no more rows than SEQ_LEN (no complete window).
    """

    def __init__(self, train=True, transform=None):
        path = "stock_daily/stock_train.csv" if train else "stock_daily/stock_test.csv"
        with open(path) as f:
            raw = np.genfromtxt(f, delimiter=",")
        # Column 0 (the date string) parses as NaN; keep the four numeric columns.
        features = raw[:, 1:5]
        n_features = features.shape[1]

        if train:
            # Reset in place so rebuilding the training set (e.g. re-running the
            # cell) replaces the statistics instead of appending after stale ones
            # that the test split (indices 0..3) would keep using.
            mean_list.clear()
            std_list.clear()
            for col in range(n_features):
                mean_list.append(np.mean(features[:, col]))
                std_list.append(np.std(features[:, col]))
        elif len(mean_list) < n_features:
            raise RuntimeError("Build Stock_Data(train=True) first: the test split is "
                               "normalised with the training statistics.")

        for col in range(n_features):
            features[:, col] = (features[:, col] - mean_list[col]) / (std_list[col] + 1e-8)

        n_windows = features.shape[0] - SEQ_LEN
        if n_windows <= 0:
            raise ValueError(f"{path} has {features.shape[0]} rows; "
                             f"need more than SEQ_LEN={SEQ_LEN}.")
        windows = np.stack([features[i:i + SEQ_LEN] for i in range(n_windows)])
        self.value = torch.from_numpy(windows).float()
        # Target of window i is feature 0 of the row right after it.
        self.label = torch.from_numpy(features[SEQ_LEN:, 0:1]).float()
        self.data = self.value

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)


# In[388]:


# Build the training split first: its statistics are used to normalise the test split.
stock_train, stock_test = Stock_Data(train=True), Stock_Data(train=False)

# LSTM model.
class LSTM(nn.Module):
    """Three-layer LSTM regressor: last time step -> Linear(128, 16) -> Linear(16, 1).

    Input is batch-first ``(batch, seq, dimension)``; output is ``(batch, 1)``.
    """

    def __init__(self, dimension):
        super().__init__()
        self.lstm = nn.LSTM(input_size=dimension, hidden_size=128, num_layers=3, batch_first=True)
        self.linear1 = nn.Linear(in_features=128, out_features=16)
        self.linear2 = nn.Linear(16, 1)

    def forward(self, x):
        sequence_out, _ = self.lstm(x)
        last_step = sequence_out[:, -1, :]
        return self.linear2(self.linear1(last_step))


# In[391]:


# Sinusoidal positional encoding added to an input tensor.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    ``forward`` indexes the encoding table by ``x``'s first dimension, so ``x``
    must be shaped ``(seq_len, batch, d_model)`` with ``seq_len <= max_len``;
    the encoding for positions ``0 .. seq_len-1`` is broadcast over the batch axis.

    Args:
        d_model: feature dimension (odd values are supported).
        max_len: longest sequence that can be encoded.
    """

    def __init__(self, d_model, max_len=SEQ_LEN):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        return x + self.pe[:x.size(0), :]


# In[392]:


class TransAm(nn.Module):
    """Transformer-encoder regressor mapping a feature window to one value.

    Input ``src`` is batch-first ``(batch, SEQ_LEN, feature_size)``, as produced
    by the DataLoader over Stock_Data; output is ``(batch, 1)``.

    Args:
        feature_size: number of input features (the encoder's d_model; nhead=4).
        num_layers: number of stacked encoder layers.
        dropout: dropout inside each encoder layer.
    """

    def __init__(self, feature_size=4, num_layers=6, dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None  # unused; kept for compatibility
        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=4, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # A linear projection stands in for a decoder: feature_size -> 1 per time
        # step, then linear1 mixes the SEQ_LEN per-step values into one output.
        self.decoder = nn.Linear(feature_size, 1)
        self.linear1 = nn.Linear(SEQ_LEN, 1)
        self.init_weights()
        self.src_key_padding_mask = None  # unused; kept for compatibility

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, seq_len=SEQ_LEN):
        """Return ``(batch, 1)`` predictions for ``src`` of shape ``(batch, SEQ_LEN, feature_size)``.

        ``seq_len`` is accepted for compatibility and unused.
        """
        # The encoder layer (batch_first=False) and PositionalEncoding (which
        # indexes positions by dim 0) both expect sequence-first input, so swap
        # the batch and time axes; otherwise attention runs across the batch.
        x = src.transpose(0, 1)                 # (SEQ_LEN, batch, feature)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = self.decoder(x).squeeze(-1)         # (SEQ_LEN, batch); keeps batch=1 intact
        return self.linear1(x.transpose(0, 1))  # (batch, 1)


# In[394]:

# Checkpoint file prefixes; train() appends "<epoch>_Model.pkl" / "<epoch>_Optimizer.pkl".
lstm_path = "./model_lstm/epoch_"
transformer_path = "./model_transformer/epoch_"
save_path = transformer_path  # switch to lstm_path when training the LSTM
def train(epoch):
    """Run one training epoch of the global `model` over `stock_train`.

    Increments the global `iteration` once per batch, appends the loss of every
    20th batch of the epoch to the global `loss_list` (and prints it), and, on
    epochs that are multiples of EPOCH, saves the model and optimizer state
    dicts to ``save_path + "<epoch>_Model.pkl"`` / ``"<epoch>_Optimizer.pkl"``.
    """
    global loss_list
    global iteration
    model.train()
    dataloader = DataLoader(dataset=stock_train, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    for i, (data, label) in enumerate(dataloader):
        iteration += 1
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        if i % 20 == 0:
            loss_list.append(loss.item())
            print("epoch=", epoch, "iteration=", iteration, "loss=", loss.item())
    if epoch % EPOCH == 0:
        # Checkpoint once, after the epoch's last step. state_dict() must be
        # *called*: saving the bound method pickles no weights at all.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), save_path + str(epoch) + "_Model.pkl")
        torch.save(optimizer.state_dict(), save_path + str(epoch) + "_Optimizer.pkl")


# In[395]:


def test():
    """Evaluate the global `model` on `stock_test` and print the mean batch MSE.

    Appends each batch's predictions to the global `predict_list` and each
    batch's MSE (computed on normalised values) to the global `accuracy_list`;
    the main loop resets both lists before each epoch.
    """
    global accuracy_list
    global predict_list
    model.eval()
    mse = nn.MSELoss()
    dataloader = DataLoader(dataset=stock_test, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    # Evaluation only: no gradients, so there is nothing for the optimizer to reset.
    with torch.no_grad():
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)
            predict = model(data)
            predict_list.append(predict)
            accuracy_list.append(mse(predict, label).item())
    # NOTE: despite the label text, the value is the plain MSE of normalised values.
    print("test_data MSELoss:(pred-real)/real=", np.mean(accuracy_list))


# In[396]:


def loss_curve(loss_list):
    """Plot the recorded training losses and save the figure as ``train_loss.png``.

    Point k (1-based) is drawn at x = 20*k, matching the one-sample-per-20-batches
    recording in train().
    """
    n_points = len(loss_list)
    iterations = 20 * np.linspace(1, n_points, n_points)
    # Reset the current figure: cla() clears the active axes, clf() all axes,
    # while keeping the window open for reuse.
    plt.cla()
    plt.clf()
    plt.plot(iterations, np.array(loss_list), label="train_loss")
    plt.ylabel("MSELoss")
    plt.xlabel("iteration")
    plt.gcf().savefig("train_loss.png", dpi=300)
    plt.show()


# In[397]:


def contrast_lines(predict_list):
    """Plot real vs. predicted test values (de-normalised) and save ``transformer_Pre.png``.

    Real labels are re-read from `stock_test` with the same batch size and
    drop_last setting that test() uses, so they line up one-to-one with the
    batches in `predict_list`. Both series are mapped back to the original
    scale with the feature-0 statistics ``mean_list[0]`` / ``std_list[0]``.
    """
    # Use BATCH_SIZE (not a hard-coded 4) so labels stay aligned with test().
    dataloader = DataLoader(dataset=stock_test, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    labels = [label.reshape(-1) for _, label in dataloader]
    predictions = [item.detach().to("cpu").reshape(-1) for item in predict_list]
    if not labels or not predictions:
        print("contrast_lines: nothing to plot")
        return

    scale, offset = std_list[0], mean_list[0]
    real = torch.cat(labels).numpy() * scale + offset
    predicted = torch.cat(predictions).numpy() * scale + offset
    n = min(len(real), len(predicted))
    x = np.arange(1, n + 1)

    plt.cla()  # clear the active axes of the current figure
    plt.clf()  # clear all axes but keep the window for reuse
    plt.plot(x, real[:n], label="real")
    plt.plot(x, predicted[:n], label="prediction")
    plt.legend()
    plt.savefig("transformer_Pre.png", dpi=300)
    plt.show()


# Choose the model: LSTM or Transformer (comment out the other pair of lines).

#model=LSTM(dimension=4)
#save_path=lstm_path
model = TransAm(feature_size=4)
save_path = transformer_path

model = model.to(device)
criterion = nn.MSELoss()

# Resume from the checkpoint that train() writes for the final epoch, if present.
# The existence check and the load must use the same file name (save_path picks
# the right directory for whichever model is selected above).
model_ckpt = save_path + str(EPOCH) + "_Model.pkl"
optimizer_ckpt = save_path + str(EPOCH) + "_Optimizer.pkl"
if os.path.exists(model_ckpt):
    model.load_state_dict(torch.load(model_ckpt, map_location=device))
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
if os.path.exists(optimizer_ckpt):
    optimizer.load_state_dict(torch.load(optimizer_ckpt, map_location=device))

# Script entry point: train for EPOCH epochs, evaluating on the test split after
# each one, then plot the training-loss curve.
if __name__=="__main__":
    symbol = '00069'
    period = 50  # only used by the commented-out draw_Kline call below
    # NOTE(review): `data` is reloaded here but training does not use it (the
    # Stock_Data datasets were built earlier from stock_daily/*.csv); it only
    # feeds the commented-out K-line drawing.
    data = import_csv(symbol)
    # df_draw=data[-period:]
    # draw_Kline(df_draw,period,symbol)
    data.drop(['Num','Date'],axis=1,inplace = True)    
    # Globals read and updated by train().
    iteration=0
    loss_list=[]
    # Train the network.
    for epoch in range(1,EPOCH+1):
        # Reset each epoch, so after the loop these hold the last epoch's test results.
        predict_list=[]
        accuracy_list=[]
        train(epoch)
        test()
    # Plot the training-loss curve.
    loss_curve(loss_list)

#In[]
# Plot predicted vs. real values on the test set.
# NOTE(review): this runs outside the __main__ guard and relies on predict_list
# having been set by the training loop above.
contrast_lines(predict_list)
# %%
