In [6]:
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
def generate_df_affect_by_n_days(series, n, index=False):
    """Turn a series into a supervised sliding-window table.

    Each output row holds ``n`` consecutive values of ``series`` in columns
    ``c0`` .. ``c{n-1}`` followed by the value that comes right after them in
    column ``y``.

    Args:
        series: A pandas Series (anything with ``tolist()`` and ``index``).
        n: Number of preceding values used as inputs for each row.
        index: When true, label each row with the index entry of its ``y``
            value (``series.index[n:]``); otherwise a default integer index
            is used.

    Returns:
        A DataFrame with ``len(series) - n`` rows and ``n + 1`` columns.

    Raises:
        ValueError: If ``series`` does not have more than ``n`` elements.
    """
    length = len(series)
    if length <= n:
        raise ValueError("The Length of series is %d, while affect by (n=%d)." % (length, n))

    # Convert once instead of once per column.
    values = series.tolist()
    row_count = length - n

    # Column i is the series shifted by i, truncated to the number of rows.
    columns = {'c%d' % i: values[i:i + row_count] for i in range(n)}
    columns['y'] = values[n:]

    # Building the frame in one step avoids repeated column insertion.
    df = pd.DataFrame(columns)
    if index:
        df.index = series.index[n:]
    return df
def readData(column='high', n=30, all_too=True, index=False, train_end=-300):
    """Load one column of "399300.csv" and build the training window table.

    The CSV is read with its first column as the index, sorted by that index,
    and the index strings are parsed as ``%Y-%m-%d`` dates. The chosen column
    is split at ``train_end`` into a training part and a test part (the test
    part starts ``n`` rows earlier so its first window is complete). The full
    column, both parts and the generated training table are also dumped to
    JSON files in the working directory.

    Args:
        column: Name of the CSV column to use.
        n: Window length passed to ``generate_df_affect_by_n_days``.
        all_too: When true, also return the full column and the date index.
        index: Forwarded to ``generate_df_affect_by_n_days``.
        train_end: Slice end for the training part of the column.

    Returns:
        The training window table, or, when ``all_too`` is true, a tuple of
        (training window table, full column, list of index dates).
    """
    raw = pd.read_csv("399300.csv", index_col=0)
    raw.sort_index(inplace=True)

    # Replace the string index with parsed datetime objects.
    raw.index = [datetime.datetime.strptime(day, "%Y-%m-%d") for day in raw.index]

    df_column = raw[column].copy()
    df_column_train = df_column[:train_end]
    df_column_test = df_column[train_end - n:]
    train_windows = generate_df_affect_by_n_days(df_column_train, n, index=index)

    # Persist every intermediate result, in the same order as before.
    dumps = (
        (df_column, 'df_column.json'),
        (df_column_train, 'df_column_train.json'),
        (df_column_test, 'df_column_test.json'),
        (train_windows, 'df_generate_from_df_column_train.json'),
    )
    for frame, target in dumps:
        frame.to_json(target)

    if not all_too:
        return train_windows
    return train_windows, df_column, raw.index.tolist()


class RNN(nn.Module):
    """Single-layer LSTM (64 hidden units) followed by a linear read-out to one value."""

    def __init__(self, input_size):
        """Create the LSTM and the output head.

        Args:
            input_size: Number of features per time step fed to the LSTM.
        """
        super().__init__()
        hidden_units = 64
        # Attribute names and creation order are kept so parameter names and
        # seeded initialisation stay the same.
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Sequential(nn.Linear(hidden_units, 1))

    def forward(self, x):
        """Run ``x`` through the LSTM and map every LSTM output to one value."""
        # No initial state is given, so the hidden state starts as all zeros.
        sequence_out, _final_state = self.rnn(x)
        return self.out(sequence_out)
class TrainSet(Dataset):
    """Dataset over a 2-D table whose last column is the label.

    Assumes ``data`` is a torch tensor (it is sliced in two dimensions and
    converted with ``.float()``).
    """

    def __init__(self, data):
        # Every column except the last is an input; the last one is the label.
        features = data[:, :-1]
        target = data[:, -1]
        self.data = features.float()
        self.label = target.float()

    def __getitem__(self, index):
        """Return the ``(inputs, label)`` pair stored at ``index``."""
        sample = self.data[index]
        return sample, self.label[index]

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)
# Window length: number of preceding values used as inputs per sample
# (passed to readData as `n`).
n = 30
# Presumably the optimizer learning rate; not used in the visible code -- confirm.
LR = 0.0001
# Presumably the number of training epochs; not used in the visible code -- confirm.
EPOCH = 100
# Slice end of the training split (passed to readData as `train_end`).
train_end = -500
# Build the dataset
# '收盘价' ("closing price") is the CSV column that is read.
df, df_all, df_index = readData('收盘价', n=n, train_end=train_end)

# Notebook-style inspection: number of dates in the full index.
len(df_index)

# Notebook-style inspection: the generated training window table.
df


Unnamed: 0,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,...,c21,c22,c23,c24,c25,c26,c27,c28,c29,y
0,1316.4600,1302.0800,1292.7100,1272.6500,1281.2600,1249.8100,1205.1500,1186.4300,1201.8800,1152.1500,...,1236.9700,1246.8400,1216.7900,1241.0200,1235.0900,1254.6100,1253.1200,1256.0000,1244.6000,1227.2300
1,1302.0800,1292.7100,1272.6500,1281.2600,1249.8100,1205.1500,1186.4300,1201.8800,1152.1500,1149.4800,...,1246.8400,1216.7900,1241.0200,1235.0900,1254.6100,1253.1200,1256.0000,1244.6000,1227.2300,1245.8700
2,1292.7100,1272.6500,1281.2600,1249.8100,1205.1500,1186.4300,1201.8800,1152.1500,1149.4800,1108.6100,...,1216.7900,1241.0200,1235.0900,1254.6100,1253.1200,1256.0000,1244.6000,1227.2300,1245.8700,1275.6100
3,1272.6500,1281.2600,1249.8100,1205.1500,1186.4300,1201.8800,1152.1500,1149.4800,1108.6100,1109.9900,...,1241.0200,1235.0900,1254.6100,1253.1200,1256.0000,1244.6000,1227.2300,1245.8700,1275.6100,1288.4400
4,1281.2600,1249.8100,1205.1500,1186.4300,1201.8800,1152.1500,1149.4800,1108.6100,1109.9900,1183.1500,...,1235.0900,1254.6100,1253.1200,1256.0000,1244.6000,1227.2300,1245.8700,1275.6100,1288.4400,1328.3900
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3824,3921.0017,3913.4462,3913.0688,3944.1626,3931.2495,3926.8520,3930.7981,3959.3953,3976.9490,3993.5752,...,4128.0733,4099.3510,4073.6696,4105.0122,4120.8508,4143.8334,4217.7015,4227.5666,4102.3966,4104.2034
3825,3913.4462,3913.0688,3944.1626,3931.2495,3926.8520,3930.7981,3959.3953,3976.9490,3993.5752,4021.9676,...,4099.3510,4073.6696,4105.0122,4120.8508,4143.8334,4217.7015,4227.5666,4102.3966,4104.2034,4049.9475
3826,3913.0688,3944.1626,3931.2495,3926.8520,3930.7981,3959.3953,3976.9490,3993.5752,4021.9676,4009.7218,...,4073.6696,4105.0122,4120.8508,4143.8334,4217.7015,4227.5666,4102.3966,4104.2034,4049.9475,4055.8235
3827,3944.1626,3931.2495,3926.8520,3930.7981,3959.3953,3976.9490,3993.5752,4021.9676,4009.7218,4006.7179,...,4105.0122,4120.8508,4143.8334,4217.7015,4227.5666,4102.3966,4104.2034,4049.9475,4055.8235,4053.7529
