In [20]:
import numpy as np
from sklearn import datasets

## Load Data

Adapted from [sklearn docs](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)

In [21]:
# Load the iris dataset bundled with scikit-learn; the returned object exposes
# the feature matrix as `.data` and the class labels as `.target`.
iris = datasets.load_iris()

In [22]:
# Feature matrix: one row per flower, four measurement columns.
iris_X = iris.data
# Integer class label for each row of iris_X.
iris_y = iris.target

In [23]:
# Peek at the first five feature rows.
iris_X[:5]

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

In [24]:
# Peek at the labels belonging to those first five rows.
iris_y[:5]

array([0, 0, 0, 0, 0])

In [25]:
# Distinct class labels present in the target vector.
np.unique(iris_y)

array([0, 1, 2])

## Split iris data into train and test

In [26]:
# Shuffle the row indices with a fixed seed so the split is reproducible,
# then hold out the last 10 shuffled rows as the test set.
np.random.seed(42)
indices = np.random.permutation(len(iris_X))
train_idx, test_idx = indices[:-10], indices[-10:]

iris_X_train, iris_y_train = iris_X[train_idx], iris_y[train_idx]
iris_X_test, iris_y_test = iris_X[test_idx], iris_y[test_idx]

## Fit nearest neighbor classifier

In [27]:
from sklearn.neighbors import KNeighborsClassifier

In [28]:
# Build the classifier with scikit-learn's default hyperparameters
# (no arguments are overridden here).
knn = KNeighborsClassifier()

In [29]:
# Fit on the training split only; the 10 held-out rows are never seen here.
knn.fit(iris_X_train, iris_y_train) 

KNeighborsClassifier()

In [30]:
# Check predictions: classify the 10 held-out rows so the result can be
# compared element-wise against iris_y_test.
knn.predict(iris_X_test)

array([1, 1, 2, 2, 0, 1, 1, 0, 1, 2])

In [31]:
# Ground-truth labels for the held-out rows, for comparison with the predictions.
iris_y_test

array([1, 1, 2, 2, 0, 1, 2, 0, 1, 2])

## Pickle model to use in REST API

Adapted from [sklearn docs](http://scikit-learn.org/stable/modules/model_persistence.html)

In [32]:
#from sklearn.externals import joblib

`sklearn.externals.joblib` is no longer available in current scikit-learn releases, so the standard-library `pickle` module is used instead.

In [33]:
import pickle

In [34]:
#pickle.dump(knn, 'iris_knn_model.pkl') 

In [16]:
# Save the fitted model to disk.
filename = 'iris_knn_model.pkl'
# Use a context manager so the file is flushed and closed as soon as the dump
# finishes; passing a bare open(...) to pickle.dump leaves the handle open
# until it happens to be garbage-collected.
with open(filename, 'wb') as model_file:
    pickle.dump(knn, model_file)

## Load pickled model and use it to predict

In [32]:
#knn_from_pkl = joblib.load('iris_knn_model.pkl')

In [18]:
# Restore the classifier from the pickle file written earlier in this notebook.
# NOTE: pickle.load can execute arbitrary code during deserialization, so only
# load files that you created yourself or otherwise fully trust.
# The context manager guarantees the file handle is closed after reading.
with open(filename, 'rb') as model_file:
    knn_from_pkl = pickle.load(model_file)

In [19]:
# Display the restored estimator to confirm it deserialized into a classifier object.
knn_from_pkl

KNeighborsClassifier()

In [35]:
# Get 1 test case. Slicing with [:1] (rather than indexing with [0]) keeps the
# result two-dimensional: a single row with all feature columns.
test_case = iris_X_test[:1]

In [36]:
# columns correspond to [Sepal Length, Sepal Width, Petal Length and Petal Width]
test_case

array([[6.3, 2.3, 4.4, 1.3]])

In [37]:
# Ground-truth label for the same held-out row selected as test_case.
test_target = iris_y_test[:1]

In [38]:
# Show the expected label for the single test case.
test_target

array([1])

In [39]:
# Predict with the reloaded model; the result should match test_target if the
# pickle round-trip preserved the fitted classifier.
knn_from_pkl.predict(test_case)

array([1])

In [40]:
# Check the container type of the model input.
# NOTE(review): presumably so the REST API knows what to convert request data into; confirm.
type(test_case)

numpy.ndarray