Skip to content
This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Permalink
Integrated sknn into the cifar10 benchmark, runs in 280s on this CPU!…
Browse files Browse the repository at this point in the history
… Fixed error in MLP logging.
  • Loading branch information
alexjc committed Apr 21, 2015
1 parent 1c8c7f5 commit 7e67b18
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 12 additions & 6 deletions examples/bench_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ def load(name):
n_feat = data_train.shape[1]
n_targets = labels_train.max() + 1

net = DBN(
[n_feat, n_feat / 3, n_targets],
epochs=50,
learn_rates=0.03,
verbose=1,
)
import sys
import logging
logging.basicConfig(format="%(message)s", level=logging.DEBUG, stream=sys.stdout)

from sknn.mlp import MultiLayerPerceptronClassifier
net = MultiLayerPerceptronClassifier(
[("Rectifier", n_feat*2/3), ("Rectifier", n_feat*1/3), ("Linear", n_targets)],
n_iter=50,
n_stable=10,
learning_rate=0.005,
valid_size=0.1,
verbose=1)
net.fit(data_train, labels_train)

from sklearn.metrics import classification_report
Expand Down
2 changes: 1 addition & 1 deletion sknn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _fit(self, X, y, test=None):
log.info("Training on dataset of {:,} samples with {:,} total size.".format(num_samples, data_size))
if self.valid_set:
X_v, _ = self.valid_set
log.debug(" - Test: {: <10,} Valid: {: <4,}".format(X.shape[0], X_v.shape[0]))
log.debug(" - Train: {: <9,} Valid: {: <4,}".format(X.shape[0], X_v.shape[0]))
if self.n_iter:
log.debug(" - Terminating loop after {} total iterations.".format(self.n_iter))
if self.n_stable:
Expand Down

0 comments on commit 7e67b18

Please sign in to comment.