Skip to content

Commit

Permalink
Merge pull request #55 from keineahnung2345/401-cnn
Browse files Browse the repository at this point in the history
[Bug fix] 401_CNN.py
  • Loading branch information
MorvanZhou committed Nov 12, 2018
2 parents 0f4219c + a9ef65e commit 40bcdb5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tutorial-contents/401_CNN.py
Expand Up @@ -65,7 +65,7 @@ def __init__(self):
out_channels=16, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
Expand Down Expand Up @@ -115,7 +115,7 @@ def plot_with_labels(lowDWeights, labels):

if step % 50 == 0:
test_output, last_layer = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy()
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
if HAS_SK:
Expand All @@ -129,6 +129,6 @@ def plot_with_labels(lowDWeights, labels):

# print 10 predictions from test data
test_output, _ = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')

0 comments on commit 40bcdb5

Please sign in to comment.