# Active Learning Example

This notebook gives a basic demo of active learning on the Iris data set.

Data set size = 150 points (5 are used as the initial train set, leaving 145 in the test set)

Number of classes = 3 (Setosa, Versicolor, Virginica)

Importing the required libraries

In [20]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly
from modAL.models import ActiveLearner
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import warnings
warnings.filterwarnings('ignore')

Reading the dataset: independent variables into X and the dependent variable into y

In [2]:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
print(X.shape, y.shape)

(150, 4) (150,)


We initially choose only 5 random data points with labels as our train set.
The train set consists of these 5 data points; the remaining 145 points form the test set, which also serves as the pool that queries are drawn from.

In [12]:
n_initial = 5
initial_idx = np.random.choice(range(len(X)), size=n_initial, replace=False)
non_idx = [i for i in range(len(X)) if i not in initial_idx]

X_training, y_training = X[initial_idx], y[initial_idx]
X_test, y_test = X[non_idx], y[non_idx]

# initializing the learner
learner = ActiveLearner(
    estimator=LogisticRegression(),
    X_training=X_training, y_training=y_training
)
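
Because `np.random.choice` is unseeded in the cell above, the 5 initial points (and every number downstream) change from run to run. A minimal sketch of a reproducible split follows. The seed value 0 is an arbitrary assumption, not part of the original notebook, and the sketch works on indices only rather than on the data itself.

```python
import numpy as np

n_samples = 150   # size of the Iris data set
n_initial = 5     # labelled points to start from

# Seeded generator; the seed value 0 is an arbitrary assumption.
rng = np.random.default_rng(0)

# Draw the initial training indices without replacement.
initial_idx = rng.choice(n_samples, size=n_initial, replace=False)

# Every other index goes to the test set / query pool.
non_idx = np.setdiff1d(np.arange(n_samples), initial_idx)

print(len(initial_idx), len(non_idx))  # 5 145
```

`X[initial_idx]` and `X[non_idx]` can then be used in the same way as in the cell above.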

We iterate through the process n_queries times. In each iteration the model is evaluated on the remaining test points, the point it is most confused about is identified from its predicted class probabilities, and that point is added to the train set together with its label. The train set therefore grows by 1 per iteration.
After 45 iterations we observe the model reaching near 100% accuracy using only 50 data points for training (5 initial + 45 queried), i.e. one third of the full dataset. Note that the cell below sets n_queries = 20, which ends with 25 training points; set n_queries = 45 to reach 50.
This is the power of active learning with intelligent sampling.
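
To make "most confused point" concrete, here is a small sketch of least-confidence uncertainty, one common way of turning prediction probabilities into a confusion score. As far as we know, modAL's default query strategy follows the same idea, but treat that as an assumption. The probabilities below are hypothetical values chosen for illustration, not output from this notebook.

```python
import numpy as np

# Hypothetical class probabilities for four pool points (each row sums to 1).
proba = np.array([
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.10, 0.80, 0.10],
    [0.55, 0.40, 0.05],
])

# Least-confidence uncertainty: 1 minus the probability of the top class.
uncertainty = 1.0 - proba.max(axis=1)

# The most confused point is the one with the highest uncertainty.
query_idx = int(np.argmax(uncertainty))

print(uncertainty)
print(query_idx)  # 1
```

The second row wins because its top class has only 0.40 probability, so the model is least sure about that point.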

In [11]:
# active learning
print("initial idxs : ", initial_idx)
performance = []
n_queries = 20
for idx in range(n_queries):
    query_idx, query_instance = learner.query(X_test)
    print("query_idx : ", query_idx)  # index with most confusion
    print("prediction probability : ", learner.predict_proba(X_test[query_idx]), "prediction class : ", learner.predict(X_test[query_idx]))

    y_pred = learner.predict(X_test)
    print(classification_report(y_test, y_pred))
    print(confusion_matrix(y_test, y_pred))
    print(accuracy_score(y_test, y_pred))
    
    performance.append(accuracy_score(y_test, y_pred))
    learner.teach(X_test[query_idx], y_test[query_idx])  # model trained on most confused data point
    
    X_test = np.delete(X_test, query_idx, 0)   # new test set X without the above trained data point
    y_test = np.delete(y_test, query_idx, 0)   # new test set y without the above trained data point

    

initial idxs :  [ 91  59  87 112  16]
query_idx :  [40]
prediction probability :  [[0.48197657 0.38676899 0.13125444]] prediction class :  [0]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        49
           1       0.49      1.00      0.66        47
           2       0.00      0.00      0.00        49

    accuracy                           0.66       145
   macro avg       0.50      0.67      0.55       145
weighted avg       0.50      0.66      0.55       145

[[49  0  0]
 [ 0 47  0]
 [ 0 49  0]]
0.6620689655172414
query_idx :  [93]
prediction probability :  [[0.28569861 0.50482329 0.20947811]] prediction class :  [1]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        48
           1       0.49      1.00      0.66        47
           2       0.00      0.00      0.00        49

    accuracy                           0.66       144
   macro avg       0.50      0.67      0.55
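
Note that each `query_idx` printed above is a position in the current, shrinking `X_test`, not an index into the original 150-point data set. The sketch below shows one way to recover the original indices by keeping an index array in sync with the pool. The names `pool_idx` and `queried_original` are hypothetical helpers, not part of the notebook; the initial indices and the two query positions are the ones printed in the run above.

```python
import numpy as np

n_samples = 150
initial_idx = np.array([91, 59, 87, 112, 16])  # "initial idxs" printed above

# Original data set index of every point still in the pool (sorted).
pool_idx = np.setdiff1d(np.arange(n_samples), initial_idx)

# Pool positions returned by the first two queries in the run above.
queried_positions = [40, 93]

queried_original = []
for pos in queried_positions:
    # Translate the pool position into an original data set index.
    queried_original.append(int(pool_idx[pos]))
    # Remove it, mirroring np.delete on X_test and y_test.
    pool_idx = np.delete(pool_idx, pos)

print(queried_original)  # [41, 98]
```

Keeping `pool_idx` alongside `X_test` and `y_test` makes it possible to look up later which original samples the learner asked for.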

In [18]:
x_axis = np.arange(n_initial, n_queries+n_initial)
fig = go.Figure(data=go.Scatter(x=x_axis, y=performance))
fig.show()

In [21]:
plotly.offline.plot(fig, filename = 'performance.html', auto_open=True)

'performance.html'