In [1]:
import tushare as ts
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import *
import talib
import seaborn as sns

from solo.getdata import *

%matplotlib inline
%load_ext autoreload
%autoreload 2



In [2]:
# As usual, a bit of setup


from cs231n.classifiers.cnn import *
from cs231n.data_utils import get_CIFAR10_data
from cs231n.gradient_check import eval_numerical_gradient_array, eval_numerical_gradient
from cs231n.layers import *
from cs231n.fast_layers import *
from cs231n.solver import Solver
from cs231n.convnet import *


plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


def rel_error(x, y):
  """ returns relative error """
  return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

# 一、获取数据
## 参数说明：
- 'name'：名称
- 'start'：开始时间
- 'data_type'：数据处理方式，'normal'：原始数据, 'carg'：复合增长率, 'grow'：增长速度
- 'ktype' ：时间周期

In [3]:
new_data = data_recieve(name='399300',start = '2004-10-01',data_type = 'carg',ktype='D')

获得原始数据：	 399300
起始时间：	 2005-04-08
结束时间：	 2017-03-07
数据个数：	 2895

      open    close     high     low      volume
0   984.66  1003.45  1003.70  979.53  14762500.0
1  1003.88   995.42  1008.73  992.77  15936100.0
2   993.71   978.70   993.71  978.20  10226200.0
3   987.95  1000.90  1006.50  987.95  16071700.0
4  1004.64   986.97  1006.42  985.58  12945700.0

查看最新数据：

                open      high     close       low    volume
2005-04-11  0.019519  0.005011 -0.008002  0.013517  0.079499
2005-04-12 -0.010131 -0.014890 -0.016797 -0.014676 -0.358300
2005-04-13 -0.005796  0.012871  0.022683  0.009967  0.571620
2005-04-14  0.016894 -0.000079 -0.013917 -0.002399 -0.194503
2005-04-15 -0.021928 -0.023658 -0.013060 -0.013850 -0.195949




# 二、建立特征
## 参数说明：
- 'data'：数据集
- 'time'：多个时间周期
- 'normal'：是否对数据进行归一化
- 'pattern' ：是否加入k模式识别
- 'period' ：预测的天数，默认为1，即用当天的特征预测第二天的值，该值为累加，如设为n，即用n天前的特征预测到今日的累积变化值

In [4]:
ml_datas = data_indicator(data=new_data,time=[5,10,20,30,60],normal=True,period=1)


# 三、设定X和Y

In [5]:
ml_datas['Price'] = pd.qcut(ml_datas['target'],10) #将target进行10等分

#y设定为符合增长率求整
X_ori = ml_datas.drop(['target','Price'],axis=1).values
y_ori = ml_datas['Price'].values
#y_ori = y_ori.reshape([y_ori.shape[0],1])
y_ori = pd.get_dummies(y_ori)
y_columns = y_ori.columns.values
print(y_columns)
y_ori = y_ori.values

[[-9.24, -2.00781], (-2.00781, -1.0489], (-1.0489, -0.554], (-0.554, -0.182], (-0.182, 0.0952], (0.0952, 0.379], (0.379, 0.749], (0.749, 1.33], (1.33, 2.205], (2.205, 9.342]]
Categories (10, object): [[-9.24, -2.00781] < (-2.00781, -1.0489] < (-1.0489, -0.554] < (-0.554, -0.182] ... (0.379, 0.749] < (0.749, 1.33] < (1.33, 2.205] < (2.205, 9.342]]


In [6]:
X_ori.shape,y_ori.shape

((2539, 784), (2539, 10))

# 四、转换格式
## 将原2维数组(N,D)转换成（N,D,H,W）的格式

In [7]:
X,y = data_trans(X_ori,y_ori,depth=28)
X.shape,y.shape

((2511, 28, 28, 28), (2511, 10))

In [8]:
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.33, random_state=42)

small_data = {
  'X_train': X_train,
  'y_train': y_train,
  'X_val': X_test,
  'y_val': y_test,
}

# 五、用TensorFlow做个CNN看下效果

In [9]:
X_train.shape,y_train.shape

((1682, 28, 28, 28), (1682, 10))

In [10]:


learning_rates = [1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1e-0,1e1,1e2,1e3]
weight_scales = [1e-6,1e-5,1e-4,1e-3,1e-2,1e-1,1e-0,1e1,1e2,1e3]



best_val_acc = 0
for learning_rate in learning_rates:
    for weight_scale in weight_scales:
            
            
        model = ManyLayer_BN_ConvNet( weight_scale=weight_scale, dtype=np.float64,
                                        use_batchnorm=True, dropout=1)
        solver = Solver(model, small_data,
                            num_epochs=1, batch_size=50,
                            update_rule='adam',
                            optim_config={
                                'learning_rate': learning_rate
                                },
                            print_every=100,
                            lr_decay=0.95,
                            verbose=False)
        solver.train()
        if solver.best_val_acc > best_val_acc:
                best_val_acc = solver.best_val_acc
                best_model = model
                cache = [learning_rate,weight_scale]
                
        print 'Val_acc:%s,    Lr:%s,    Ws:%s:' % (solver.best_val_acc, learning_rate,weight_scale)
    

print 'done!'
print cache

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (50,) (50,10) 

In [None]:
learning_rate_best,weight_scale_best =  cache


           
final_model = ManyLayer_BN_ConvNet(weight_scale=weight_scale_best, dtype=np.float64,
                                        use_batchnorm=True, dropout=1)         
final_solver = Solver(final_model, 
                            small_data,
                            num_epochs=5, 
                            batch_size=50,
                            update_rule='adam',
                            optim_config={'learning_rate': learning_rate_best},
                            print_every=100,
                            lr_decay=0.95,
                            verbose=False)

final_solver.train()

In [None]:
#see

plt.subplot(2, 1, 1)
plt.plot(final_solver.loss_history, 'o')
plt.xlabel('iteration')
plt.ylabel('loss')

plt.subplot(2, 1, 2)
plt.plot(final_solver.train_acc_history, '-o')
plt.plot(final_solver.val_acc_history, '-o')
plt.legend(['train', 'val'], loc='upper left')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.show()

In [None]:
y_test_pred = np.argmax(final_model.loss(small_data['X_test']), axis=1)
y_val_pred = np.argmax(final_model.loss(small_data['X_val']), axis=1)
print 'Validation set accuracy: ', (y_val_pred == small_data['y_val']).mean()
print 'Test set accuracy: ', (y_test_pred == small_data['y_test']).mean()

Val_acc:0.496389939088,    Lr:1e-06,    Ws:10.0:
Val_acc:0.493084727765,    Lr:1e-06,    Ws:1000.0:
Val_acc:0.528985507246,    Lr:0.0001,    Ws:1e-05:
Val_acc:0.528985507246,    Lr:0.001,    Ws:1e-06:
Val_acc:0.528985507246,    Lr:0.001,    Ws:0.01:
Val_acc:0.528985507246,    Lr:0.1,    Ws:1e-05:
Val_acc:0.528985507246,    Lr:0.1,    Ws:0.001:
Val_acc:0.528985507246,    Lr:0.1,    Ws:0.1:
Val_acc:0.528349553082,    Lr:10.0,    Ws:1e-06:
Val_acc:0.528985507246,    Lr:10.0,    Ws:1e-05:
Val_acc:0.528985507246,    Lr:100.0,    Ws:0.01:
Val_acc:0.528985507246,    Lr:100.0,    Ws:0.1:
Val_acc:0.528985507246,    Lr:100.0,    Ws:100.0:
Val_acc:0.515230810521,    Lr:1000.0,    Ws:1e-05:
Val_acc:0.528985507246,    Lr:1000.0,    Ws:0.1:

Val_acc:0.0012077294686,    Lr:1e-06,    Ws:1e-06:
Val_acc:0.0700483091787,    Lr:1e-06,    Ws:1e-05:
Val_acc:0.00239212116969,    Lr:1e-06,    Ws:0.0001:
Val_acc:0.0796838899391,    Lr:1e-06,    Ws:0.001:
Val_acc:0.099869600224,    Lr:1e-06,    Ws:0.01:
Val_acc:0.301669817265,    Lr:1e-06,    Ws:0.1:
Val_acc:0.00653457490256,    Lr:1e-06,    Ws:1.0:
Val_acc:0.496389939088,    Lr:1e-06,    Ws:10.0:
Val_acc:0.0619223902541,    Lr:1e-06,    Ws:100.0:
Val_acc:0.493084727765,    Lr:1e-06,    Ws:1000.0:
Val_acc:0.0700483091787,    Lr:1e-05,    Ws:1e-06:
Val_acc:0.0700483091787,    Lr:1e-05,    Ws:1e-05:
Val_acc:0.00866268757731,    Lr:1e-05,    Ws:0.0001:
Val_acc:0.134775607365,    Lr:1e-05,    Ws:0.001:
Val_acc:0.000558647809751,    Lr:1e-05,    Ws:0.01:
Val_acc:0.152510852062,    Lr:1e-05,    Ws:0.1:
Val_acc:0.0249013979323,    Lr:1e-05,    Ws:1.0:
Val_acc:0.0159965810171,    Lr:1e-05,    Ws:10.0:
Val_acc:0.12347575206,    Lr:1e-05,    Ws:100.0:
Val_acc:0.0490939111765,    Lr:1e-05,    Ws:1000.0:
Val_acc:0.0012077294686,    Lr:0.0001,    Ws:1e-06:
Val_acc:0.528985507246,    Lr:0.0001,    Ws:1e-05:
Val_acc:0.0024154589372,    Lr:0.0001,    Ws:0.0001:
Val_acc:0.0157004830918,    Lr:0.0001,    Ws:0.001:
Val_acc:0.0178242199351,    Lr:0.0001,    Ws:0.01:
Val_acc:0.0774595089734,    Lr:0.0001,    Ws:0.1:
Val_acc:0.0162547550701,    Lr:0.0001,    Ws:1.0:
Val_acc:0.00290263483395,    Lr:0.0001,    Ws:10.0:
Val_acc:0.0501266073887,    Lr:0.0001,    Ws:100.0:
Val_acc:0.00385656608089,    Lr:0.0001,    Ws:1000.0:
Val_acc:0.528985507246,    Lr:0.001,    Ws:1e-06:
Val_acc:0.0012077294686,    Lr:0.001,    Ws:1e-05:
Val_acc:0.135265700483,    Lr:0.001,    Ws:0.0001:
Val_acc:0.0012077294686,    Lr:0.001,    Ws:0.001:
Val_acc:0.528985507246,    Lr:0.001,    Ws:0.01:
Val_acc:0.313353287125,    Lr:0.001,    Ws:0.1:
Val_acc:0.104236679969,    Lr:0.001,    Ws:1.0:
Val_acc:0.015642138673,    Lr:0.001,    Ws:10.0:
Val_acc:0.0154904431842,    Lr:0.001,    Ws:100.0:
Val_acc:0.270411794908,    Lr:0.001,    Ws:1000.0:
Val_acc:0.0338164251208,    Lr:0.01,    Ws:1e-06:
Val_acc:0.135265700483,    Lr:0.01,    Ws:1e-05:
Val_acc:0.0700483091787,    Lr:0.01,    Ws:0.0001:
Val_acc:0.0012077294686,    Lr:0.01,    Ws:0.001:
Val_acc:0.0157004830918,    Lr:0.01,    Ws:0.01:
Val_acc:0.150957431912,    Lr:0.01,    Ws:0.1:
Val_acc:0.172099990665,    Lr:0.01,    Ws:1.0:
Val_acc:0.0801506452893,    Lr:0.01,    Ws:10.0:
Val_acc:0.0108243482928,    Lr:0.01,    Ws:100.0:
Val_acc:0.128869693575,    Lr:0.01,    Ws:1000.0:
Val_acc:0.135265700483,    Lr:0.1,    Ws:1e-06:
Val_acc:0.528985507246,    Lr:0.1,    Ws:1e-05:
Val_acc:0.135265700483,    Lr:0.1,    Ws:0.0001:
Val_acc:0.528985507246,    Lr:0.1,    Ws:0.001:
Val_acc:0.0036231884058,    Lr:0.1,    Ws:0.01:
Val_acc:0.528985507246,    Lr:0.1,    Ws:0.1:
Val_acc:0.0265087866695,    Lr:0.1,    Ws:1.0:
Val_acc:0.04298233331,    Lr:0.1,    Ws:10.0:
Val_acc:0.123574937571,    Lr:0.1,    Ws:100.0:
Val_acc:0.0327253844897,    Lr:0.1,    Ws:1000.0:
Val_acc:0.135265700483,    Lr:1.0,    Ws:1e-06:
Val_acc:0.135265700483,    Lr:1.0,    Ws:1e-05:
Val_acc:0.135265700483,    Lr:1.0,    Ws:0.0001:
Val_acc:0.0157004830918,    Lr:1.0,    Ws:0.001:
Val_acc:0.0700483091787,    Lr:1.0,    Ws:0.01:
Val_acc:0.135265700483,    Lr:1.0,    Ws:0.1:
Val_acc:0.0549050152862,    Lr:1.0,    Ws:1.0:
Val_acc:0.0211060934911,    Lr:1.0,    Ws:10.0:
Val_acc:0.357573980723,    Lr:1.0,    Ws:100.0:
Val_acc:0.213403463325,    Lr:1.0,    Ws:1000.0:
Val_acc:0.528349553082,    Lr:10.0,    Ws:1e-06:
Val_acc:0.528985507246,    Lr:10.0,    Ws:1e-05:
Val_acc:0.0338164251208,    Lr:10.0,    Ws:0.0001:
Val_acc:0.0338164251208,    Lr:10.0,    Ws:0.001:
Val_acc:0.0700483091787,    Lr:10.0,    Ws:0.01:
Val_acc:0.0012077294686,    Lr:10.0,    Ws:0.1:
Val_acc:0.0036231884058,    Lr:10.0,    Ws:1.0:
Val_acc:0.0012077294686,    Lr:10.0,    Ws:10.0:
Val_acc:0.342223564144,    Lr:10.0,    Ws:100.0:
Val_acc:0.194669070457,    Lr:10.0,    Ws:1000.0:
Val_acc:0.0012077294686,    Lr:100.0,    Ws:1e-06:
Val_acc:0.0157004830918,    Lr:100.0,    Ws:1e-05:
Val_acc:0.0024154589372,    Lr:100.0,    Ws:0.0001:
Val_acc:0.0157004830918,    Lr:100.0,    Ws:0.001:
Val_acc:0.528985507246,    Lr:100.0,    Ws:0.01:
Val_acc:0.528985507246,    Lr:100.0,    Ws:0.1:
Val_acc:0.0338164251208,    Lr:100.0,    Ws:1.0:
Val_acc:0.0700483091787,    Lr:100.0,    Ws:10.0:
Val_acc:0.528985507246,    Lr:100.0,    Ws:100.0:
Val_acc:0.135143177204,    Lr:100.0,    Ws:1000.0:
Val_acc:0.0012077294686,    Lr:1000.0,    Ws:1e-06:
Val_acc:0.515230810521,    Lr:1000.0,    Ws:1e-05:
Val_acc:0.0338164251208,    Lr:1000.0,    Ws:0.0001:
Val_acc:0.0700483091787,    Lr:1000.0,    Ws:0.001:
Val_acc:0.135265700483,    Lr:1000.0,    Ws:0.01:
Val_acc:0.528985507246,    Lr:1000.0,    Ws:0.1:
Val_acc:0.0338164251208,    Lr:1000.0,    Ws:1.0:
Val_acc:0.135265700483,    Lr:1000.0,    Ws:10.0:
Val_acc:0.0338164251208,    Lr:1000.0,    Ws:100.0:
Val_acc:0.135265700483,    Lr:1000.0,    Ws:1000.0: