# *k*-NN Metric Learning in `scikit-learn`

https://pypi.org/project/metric-learn/

`metric-learn` is not part of core `scikit-learn`; it is part of `scikit-learn-contrib`.  
This notebook presents an example of the `LMNN` (Large Margin Nearest Neighbor) algorithm.

Performance is compared against *vanilla* *k*-NN on the breast-cancer dataset. 

Two options are presented:  
 - A separate metric learning process with the learned metric passed to `KNeighborsClassifier`.
 - The metric learning linked to `KNeighborsClassifier` in a pipeline.

The `metric-learn` library needs to be installed first, using either  
`pip install metric-learn`  or  
`conda install -c conda-forge metric-learn`.

In [None]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from metric_learn import LMNN
from sklearn import metrics

## Load the dataset


In [None]:
breast_data = load_breast_cancer()
X = breast_data['data']
y = breast_data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 1/2, random_state=42)

Measure the hold-out accuracy of *vanilla* *k*-NN as a baseline.  

In [None]:
kNN = KNeighborsClassifier()
kNN.fit(X_train,y_train)
y_pred = kNN.predict(X_test)
knn_acc = metrics.accuracy_score(y_test, y_pred)  # signature is (y_true, y_pred)
print("Hold-Out Testing - basic k-NN: {0:4.2f}".format(knn_acc))

Learn the distance metric using the LMNN algorithm.

In [None]:
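# Newer metric-learn releases (0.6 and later) rename `k` to `n_neighbors`;
# with those versions use LMNN(n_neighbors=5, learn_rate=1e-6) instead.
lmnn = LMNN(k=5, learn_rate=1e-6)
lmnn.fit(X_train, y_train)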
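
The metric learned by LMNN is a Mahalanobis distance: the algorithm fits a linear map `L`, and the distance between two points is the Euclidean distance between their images under `L`, which equals sqrt((a - b)ᵀ M (a - b)) with M = LᵀL. The sketch below illustrates that identity with NumPy only. The matrix `L` and the points `a` and `b` are hand-picked for illustration, not values learned from the breast-cancer data, and `mahalanobis` is a hypothetical helper rather than part of `metric-learn`.

```python
import numpy as np

# Hypothetical 2x2 linear map standing in for the matrix L that LMNN would learn.
L = np.array([[2.0, 0.0],
              [1.0, 0.5]])
M = L.T @ L  # Mahalanobis matrix induced by L


def mahalanobis(a, b, M):
    """Return sqrt((a - b)^T M (a - b))."""
    d = a - b
    return float(np.sqrt(d @ M @ d))


a = np.array([1.0, 3.0])
b = np.array([4.0, -1.0])

# Distance computed with the metric in the original space.
d_metric = mahalanobis(a, b, M)
# Euclidean distance after mapping both points through L.
d_transformed = float(np.linalg.norm(L @ a - L @ b))

print(d_metric, d_transformed, np.isclose(d_metric, d_transformed))
```

The two numbers agree up to floating-point error, which is why a learned metric can be used either as a distance function or as a preprocessing transform.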

Run *k*-NN again using the learned metric. 

In [None]:
knnMet = KNeighborsClassifier(metric=lmnn.get_metric())
knnMet.fit(X_train, y_train)
y_pred = knnMet.predict(X_test)
M_knn_acc = metrics.accuracy_score(y_test, y_pred)  # signature is (y_true, y_pred)
print("Hold-Out Testing - Metric Learning k-NN: {0:4.2f}".format(M_knn_acc))

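Run *k*-NN with the learned metric, this time using a pipeline.  
In the pipeline, `LMNN` transforms the data and `KNeighborsClassifier` uses Euclidean distance in the transformed space, so this should produce the same result.  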

In [None]:
from sklearn.pipeline import make_pipeline

clf = make_pipeline(LMNN(k=5, learn_rate=1e-6), KNeighborsClassifier())
clf.fit(X_train,y_train)

y_pred = clf.predict(X_test)
M_knn_acc = metrics.accuracy_score(y_test, y_pred)  # signature is (y_true, y_pred)
print("Hold-Out Testing - Metric Learning k-NN: {0:4.2f}".format(M_knn_acc))
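
Why the two options should agree can be checked on toy data without `metric-learn` at all. The sketch below is an illustration under stated assumptions: the data are synthetic, `L` is a random matrix standing in for a learned transform, and `knn_predict` is a hypothetical brute-force majority-vote helper, not the `KNeighborsClassifier` implementation. It compares *k*-NN using the Mahalanobis distance in the original space with *k*-NN using Euclidean distance after transforming the data by `L`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic two-class data (not the breast-cancer dataset).
X_tr = rng.normal(size=(40, 3))
y_tr = (X_tr[:, 0] + X_tr[:, 1] > 0).astype(int)
X_te = rng.normal(size=(10, 3))

# Random matrix standing in for a learned linear map (hypothetical).
L = rng.normal(size=(3, 3))
M = L.T @ L


def knn_predict(dist_matrix, labels, k=5):
    """Brute-force k-NN: majority vote over the k smallest distances per row."""
    nearest = np.argsort(dist_matrix, axis=1)[:, :k]
    votes = labels[nearest]
    return (votes.mean(axis=1) > 0.5).astype(int)


# Option 1: Mahalanobis distance with M in the original feature space.
diff = X_te[:, None, :] - X_tr[None, :, :]
d_metric = np.sqrt(np.einsum("ijk,kl,ijl->ij", diff, M, diff))

# Option 2: transform with L first, then use plain Euclidean distance.
Z_tr, Z_te = X_tr @ L.T, X_te @ L.T
d_pipe = np.linalg.norm(Z_te[:, None, :] - Z_tr[None, :, :], axis=2)

pred_metric = knn_predict(d_metric, y_tr)
pred_pipe = knn_predict(d_pipe, y_tr)
print(np.array_equal(pred_metric, pred_pipe))
```

In practice the pipeline form is usually the faster of the two with `scikit-learn`, since a Python callable passed as `metric` is evaluated pair by pair, whereas the transformed data can use the built-in Euclidean distance.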