In [1]:
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Build the model
def create_model():
    """Build and compile the baseline binary classifier.

    Architecture: 8 inputs -> Dense(12, relu) -> Dense(8, relu) -> Dense(1, sigmoid).
    Compiled with binary cross-entropy loss, the Adam optimizer and an
    accuracy metric.

    Returns:
        A compiled ``Sequential`` model, ready to be passed to ``fit``.
    """
    stack = [
        Dense(units=12, input_dim=8, activation='relu'),
        Dense(units=8, activation='relu'),
        Dense(units=1, activation='sigmoid'),
    ]
    net = Sequential(stack)

    # Compile the model
    net.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return net

In [3]:
# Single source of truth for randomness in this notebook.
seed = 7
# Seed NumPy's global RNG from the `seed` constant rather than repeating the
# literal, so changing `seed` above changes every consumer consistently.
# NOTE(review): this does not seed the TensorFlow backend's own RNG, so Keras
# weight initialization may still differ between runs — confirm for the
# Keras/TensorFlow versions in use.
np.random.seed(seed)

In [4]:
# Load the data: columns 0-7 are the input features (X) and column 8 is the
# target (y) used for binary classification.
dataset = np.loadtxt('./pima-indians-diabetes.data.csv', delimiter=',')
X, y = dataset[:, 0:8], dataset[:, 8]

In [5]:
# Wrap the Keras model builder so scikit-learn utilities can train and score it:
# 150 epochs, mini-batches of 10 samples, no Keras progress output.
model = KerasClassifier(
    build_fn=create_model,
    batch_size=10,
    epochs=150,
    verbose=0,
)

In [6]:
# 10-fold stratified cross-validation with shuffled folds seeded by `seed`;
# n_jobs=-1 lets scikit-learn run the folds in parallel on all cores.
kfold = StratifiedKFold(10, shuffle=True, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold, n_jobs=-1)

In [7]:
# Accuracy of each of the 10 cross-validation folds
results

array([0.64935066, 0.6883117 , 0.71428572, 0.79220779, 0.7922078 ,
       0.68831168, 0.35064936, 0.64935066, 0.71052631, 0.68421054])

In [8]:
# Summarize the per-fold accuracies. The mean is the headline estimate, and the
# standard deviation shows how much it varies from fold to fold; the mean alone
# hides that spread.
print('Accuracy: %.4f (+/- %.4f)' % (results.mean(), results.std()))

0.6719412205352492


# 网格搜索 (Grid Search)

In [9]:
def create_model_grid(optimizer='adam', init='glorot_uniform'):
    """Build and compile the classifier with a tunable optimizer and initializer.

    Same 8 -> 12 -> 8 -> 1 architecture (relu, relu, sigmoid) as the baseline,
    but the weight initializer and optimizer are parameters so that
    GridSearchCV can vary them.

    Args:
        optimizer: optimizer passed to ``model.compile``.
        init: ``kernel_initializer`` applied to every Dense layer.

    Returns:
        A compiled ``Sequential`` model.
    """
    net = Sequential()
    # (units, activation) for each layer, from input side to output side.
    layer_specs = [(12, 'relu'), (8, 'relu'), (1, 'sigmoid')]
    for position, (n_units, act) in enumerate(layer_specs):
        # Only the first layer declares the input width.
        first_layer_kwargs = {'input_dim': 8} if position == 0 else {}
        net.add(Dense(units=n_units, kernel_initializer=init, activation=act,
                      **first_layer_kwargs))

    net.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return net

In [10]:
# scikit-learn wrapper used by the grid search. GridSearchCV sets `optimizer`
# and `init` (arguments of create_model_grid) and `epochs`/`batch_size` (fit
# parameters) on it for each candidate.
# verbose=0, like the cross-validation wrapper, so that per-epoch Keras progress
# logs from many parallel fits do not flood the cell output.
model_grid = KerasClassifier(build_fn=create_model_grid, verbose=0)

In [11]:
# Hyper-parameter grid; every combination is tried (2 x 3 x 4 x 3 = 72 settings).
# `optimizer` and `init` go to create_model_grid, `epochs` and `batch_size` to fit.
params = dict(
    optimizer=['rmsprop', 'adam'],
    init=['glorot_uniform', 'normal', 'uniform'],
    epochs=[50, 100, 150, 200],
    batch_size=[5, 10, 20],
)

In [12]:
# Exhaustive search over every combination in `params`, each scored by
# cross-validation. The folds are given explicitly (stratified, shuffled,
# seeded with `seed`) so the search is reproducible. This also avoids depending
# on scikit-learn's default `cv`, which is unshuffled and changed from 3 to 5
# folds in version 0.22.
# 3 folds keeps the cost bounded: 72 settings x 3 folds = 216 Keras fits,
# plus one refit of the best setting on all data.
grid_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=seed)
grid = GridSearchCV(estimator=model_grid, param_grid=params, cv=grid_cv, n_jobs=-1)

In [13]:
# Run the search and report the winning setting. This is expensive: every
# combination in `params` is trained once per CV fold, with up to 200 epochs
# and batch sizes down to 5, plus a final refit of the best setting.
grid_result = grid.fit(X, y)
print('Best: %f using %s' % (grid_result.best_score_, grid_result.best_params_))



KeyboardInterrupt: 