# Import libraries

In [None]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
# Pin NumPy's global RNG to a fixed seed so the random draws below repeat run to run.
RANDOM_STATE_SEED = 123
np.random.seed(seed=RANDOM_STATE_SEED)

# Get Dataset

In [None]:
# Load the Iris dataset and pull out the feature matrix and the label vector.
iris = load_iris()
Xdata = iris['data']
ydata = iris['target']
# Report both array shapes as a quick sanity check.
print(f'IRIS data shape: {Xdata.shape}')
print(f'IRIS target shape: {ydata.shape}')

In [None]:
# Show the distinct class labels that occur in the target vector.
class_labels = np.unique(ydata)
print(class_labels)

In [None]:
# Embed the feature matrix into two dimensions with t-SNE.
# The embedding is seeded with the notebook-wide seed; the scatter plots below draw it.
tsne = TSNE(
    n_components=2,
    random_state=RANDOM_STATE_SEED,
)
Xdata_dr = tsne.fit_transform(Xdata)

In [None]:
# Scatter the two t-SNE components, coloring each point by its class label.
fig, ax = plt.subplots(figsize=(5, 3), dpi=130)
ax.scatter(
    x=Xdata_dr[:, 0],
    y=Xdata_dr[:, 1],
    c=ydata,
    cmap='viridis',
    s=10,
    alpha=0.8,
)
ax.set_xlabel('DR 1')
ax.set_ylabel('DR 2')
plt.show()

In [None]:
# Isolate our examples for our labeled dataset.
n_labeled_examples = Xdata.shape[0]
# Draw 3 distinct row indices from [0, n_labeled_examples).
# np.random.randint(high=n + 1) could return n itself (its upper bound is
# exclusive), which is one past the last valid row, and it samples with
# replacement, so it could also repeat an index.
training_indices = np.random.choice(n_labeled_examples, size=3, replace=False)
print('Training indices: {}'.format(training_indices))
Xdata_train = Xdata[training_indices]
ydata_train = ydata[training_indices]
# Isolate the non-training examples we'll be querying.
X_pool = np.delete(Xdata, training_indices, axis=0)
y_pool = np.delete(ydata, training_indices, axis=0)
# Every training row must have been removed from the pool exactly once.
assert X_pool.shape[0] == n_labeled_examples - len(training_indices)
print('Pool X , y shapes: {}, {}'.format(X_pool.shape, y_pool.shape))

# Active learning with pool-based sampling

<img src="../PoolBasedLearningcycle.png" alt="Pool-based active learning cycle (source: https://burrsettles.com/pub/settles.activelearning.pdf)" style="float: left; margin-right: 10px" />

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from modAL.models import ActiveLearner

In [None]:
# Build a 3-nearest-neighbour classifier and wrap it in a modAL ActiveLearner
# that starts from the initial labeled examples.
knn = KNeighborsClassifier(n_neighbors=3)
learner = ActiveLearner(
    estimator=knn,
    X_training=Xdata_train,
    y_training=ydata_train,
)

In [None]:
# Predict every sample with the initial learner and flag which predictions match the labels.
predictions = learner.predict(Xdata)
is_correct = predictions == ydata
# Score on the full dataset before any query has been made.
unqueried_score = learner.score(Xdata, ydata)

# Draw correct and incorrect predictions on the t-SNE embedding, one scatter call per group.
fig, ax = plt.subplots(figsize=(8.5, 6), dpi=130)
for point_mask, point_color, point_marker, point_label in (
    (is_correct, 'g', '+', 'Correct'),
    (~is_correct, 'r', 'x', 'Incorrect'),
):
    ax.scatter(
        x=Xdata_dr[point_mask, 0],
        y=Xdata_dr[point_mask, 1],
        c=point_color,
        marker=point_marker,
        label=point_label,
        alpha=0.8,
    )
ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
ax.legend()
ax.set_title(f"base AL (Accuracy: {unqueried_score:.3f})")
plt.show()

# Update our model by pool-based sampling our “unlabeled” dataset

In [None]:
# Number of samples the learner may ask to have labeled.
N_QUERIES = 20
# Accuracy after each step; entry 0 is the score before any query.
history = [unqueried_score]

for index in range(N_QUERIES):
    # Ask the learner which pool sample it wants labeled next.
    query_index, query_instance = learner.query(X_pool)
    # Teach the learner the queried sample together with its label.
    X, y = X_pool[query_index].reshape(1, -1), y_pool[query_index].reshape(1, )
    learner.teach(X=X, y=y)
    # Remove the queried instance from the pool so it cannot be queried again.
    X_pool, y_pool = np.delete(X_pool, query_index, axis=0), np.delete(y_pool, query_index)
    # Score on the full dataset, the same data unqueried_score was measured on.
    # Scoring on (X, y) would only evaluate the single sample just taught.
    model_acc = learner.score(Xdata, ydata)
    print(f'Accuracy after query {index + 1}: {model_acc:0.3f}')
    # Save performance.
    history.append(model_acc)

# Model performance

In [None]:
# Plot the recorded score against the number of queries made so far.
fig, ax = plt.subplots(figsize=(4, 3), dpi=200)

# A line for the trend plus a marker at every recorded point.
query_counts = range(len(history))
ax.plot(history)
ax.scatter(query_counts, history, s=13)

# Integer ticks on the query axis; percentage ticks on the score axis.
x_locator = mpl.ticker.MaxNLocator(nbins=5, integer=True)
y_locator = mpl.ticker.MaxNLocator(nbins=10)
y_formatter = mpl.ticker.PercentFormatter(xmax=1)
ax.xaxis.set_major_locator(x_locator)
ax.yaxis.set_major_locator(y_locator)
ax.yaxis.set_major_formatter(y_formatter)

# Show the full 0-1 score range and label both axes.
ax.set(ylim=[0, 1], xlabel='Queries', ylabel='Acc')

plt.show()

In [None]:
# Re-predict every sample with the updated learner and flag which predictions match the labels.
predictions = learner.predict(Xdata)
is_correct = (predictions == ydata)
print(predictions.shape, is_correct.shape)
# Accuracy of exactly the predictions drawn below. Deriving it from is_correct
# keeps the title in agreement with the plot whatever history[-1] holds.
final_acc = is_correct.mean()

fig, ax = plt.subplots(figsize=(4, 3), dpi=200)
ax.scatter(
            x=Xdata_dr[:, 0][is_correct],  y=Xdata_dr[:, 1][is_correct],
            c='b', marker='*', label='Correct')
ax.scatter(
            x=Xdata_dr[:, 0][~is_correct], y=Xdata_dr[:, 1][~is_correct],
            c='r', marker='o', label='Incorrect')
# Same axis labels as the baseline plot, so the two figures read alike.
ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
ax.set_title('{count} AL queries --> {acc:.2f}'.format(count=N_QUERIES, acc=final_acc))
ax.legend(loc='best')
plt.show()