Skip to content

Commit 2bf5051

Browse files
committed
knn.predict(test[3:5])
1 parent cf5048a commit 2bf5051

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

ch46-机器学习-K近邻/2-使用kNN对手写数字OCR.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
matches = result == test_labels
4242
correct = np.count_nonzero(matches)
4343
accuracy = correct * 100.0 / result.size
44-
print('准确率', accuracy) # 准确率91%
44+
print('准确率', accuracy) # 准确率91.76%
4545

4646
''''''
4747
# save the data
@@ -56,5 +56,5 @@
5656

5757

5858
#TODO 怎样预测数字?
59-
# knn.predict?
59+
retval, results=knn.predict(test[3:5])
6060
# Docstring: predict(samples[, results[, flags]]) -> retval, results
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2017/8/8 12:33
3+
# @Author : play4fun
4+
# @File : knn-find_nearest.py
5+
# @Software: PyCharm
6+
7+
"""
8+
knn-find_nearest.py:
9+
http://www.bogotobogo.com/python/OpenCV_Python/python_opencv3_Machine_Learning_Classification_K-nearest_neighbors_k-NN.php
10+
"""
11+
12+
import cv2
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
16+
# Feature set containing (x,y) values of 25 known/training data
17+
trainData = np.random.randint(0, 100, (25, 2)).astype(np.float32)
18+
19+
# Labels each one either Red or Blue with numbers 0 and 1
20+
responses = np.random.randint(0, 2, (25, 1)).astype(np.float32)
21+
22+
# plot Reds
23+
red = trainData[responses.ravel() == 0]
24+
plt.scatter(red[:, 0], red[:, 1], 80, 'r', '^')
25+
26+
# plot Blues
27+
blue = trainData[responses.ravel() == 1]
28+
plt.scatter(blue[:, 0], blue[:, 1], 80, 'b', 's')
29+
30+
# CvKNearest instance
31+
# knn = cv2.KNearest()
32+
knn = cv2.ml.KNearest_create()
33+
# trains the model
34+
knn.train(trainData, responses)#TODO
35+
#TypeError: only length-1 arrays can be converted to Python scalars
36+
37+
38+
# New sample : (x,y)
39+
newcomer = np.random.randint(0, 100, (1, 2)).astype(np.float32)
40+
plt.scatter(newcomer[:, 0], newcomer[:, 1], 80, 'g', 'o')
41+
42+
# Finds the 3nearest neighbors and predicts responses for input vectors
43+
ret, results, neighbours, dist = knn.find_nearest(newcomer, 3)
44+
45+
print("result: ", results, "\n")
46+
print("neighbours: ", neighbours, "\n")
47+
print("distance: ", dist)
48+
49+
plt.show()

0 commit comments

Comments
 (0)