Skip to content

Commit

Permalink
add CNN support
Browse files Browse the repository at this point in the history
  • Loading branch information
MorvanZhou committed Oct 30, 2018
1 parent 6310c50 commit e45cac5
Show file tree
Hide file tree
Showing 18 changed files with 561 additions and 271 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ for _ in range(1000):

## Demo
* A naked and step-by-step [network](/simple_nn.py) without using my module.
* [Train a regressor](/train_regressor.py)
* [Train a classifier](/train_classifier.py)
* [Train regressor](/train_regressor.py)
* [Train classifier](/train_classifier.py)
* [Train CNN](/train_cnn.py)
* [Save and restore a trained net](/save_model.py)


## Download or fork
Download [link](https://github.com/MorvanZhou/simple-neural-networks/archive/master.zip)

Expand Down
Binary file added mnist.npz
Binary file not shown.
18 changes: 10 additions & 8 deletions neuralnets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .nn import layers
from .nn import initializers as init
from .nn import activations as act
from .nn.module import Module
from .nn import losses
from .nn import optimizers as optim
from .nn.variable import Variable
from .nn.saver import Saver
from . import layers
from . import initializers as init
from . import activations as act
from .module import Module
from . import losses
from . import optimizers as optim
from .variable import Variable
from .saver import Saver
from .dataloader import DataLoader
from . import metrics
File renamed without changes.
28 changes: 28 additions & 0 deletions neuralnets/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np


class DataLoader:
def __init__(self, x, y, batch_size):
self.x = x
self.y = y
self.bs = batch_size
self.p = 0
self.bg = self.batch_generator()

def batch_generator(self):
while True:
p_ = self.p + self.bs
if p_ > len(self.x):
self.p = 0
continue
if self.p == 0:
indices = np.random.permutation(len(self.x))
self.x[:] = self.x[indices]
self.y[:] = self.y[indices]
bx = self.x[self.p:p_]
by = self.y[self.p:p_]
self.p = p_
yield bx, by

def next_batch(self):
return next(self.bg)
File renamed without changes.
Loading

0 comments on commit e45cac5

Please sign in to comment.