In [1]:
from sail.models.torch.onn_hbp import ONNHBPClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.utils import gen_batches
import numpy as np
import torch

from sail.transformers.river.preprocessing import StandardScaler

### 2. Load the Iris dataset


In [2]:
iris = load_iris()
X = iris["data"]
y = iris["target"]
names = iris["target_names"]
feature_names = iris["feature_names"]

# Scale data to have mean 0 and variance 1
# which is importance for convergence of the neural network
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Split the data set into training and testing
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=2
)

### 3. Train and test ONN on Iris dataset


In [3]:
batch_size = 1

model_skorch = ONNHBPClassifier(
    input_units=4, output_units=3, hidden_units=50, n_hidden_layers=3
)

for batch in gen_batches(X_train.shape[0], batch_size):
    x_batch = X_train[batch]
    y_batch = y_train[batch]
    partial_fit = model_skorch.partial_fit(x_batch, y_batch)

predict = model_skorch.predict(X_test)

print(partial_fit.score(X_test, y_test))

  epoch    train_loss     dur
-------  ------------  ------
      1        [36m1.1044[0m  0.0154
      2        [36m1.0916[0m  0.0024


      3        1.1683  0.0028
      4        1.1564  0.0025
      5        1.1796  0.0019
      6        1.1024  0.0019
      7        1.1325  0.0023
      8        1.1457  0.0019
      9        1.1591  0.0018
     10        1.1121  0.0021
     11        1.1525  0.0022
     12        1.1448  0.0021
     13        1.1063  0.0020
     14        1.1247  0.0017
     15        [36m1.0658[0m  0.0016
     16        1.1212  0.0019
     17        [36m1.0607[0m  0.0022
     18        1.1220  0.0019
     19        [36m1.0511[0m  0.0020
     20        [36m1.0402[0m  0.0019
     21        [36m1.0341[0m  0.0016
     22        1.0670  0.0021
     23        1.0512  0.0022
     24        1.0724  0.0019
     25        [36m1.0317[0m  0.0017
     26        1.0691  0.0020
     27        [36m1.0183[0m  0.0020
     28        1.0730  0.0017
     29        1.1358  0.0020
     30        [36m0.9972[0m  0.0017
     31        1.0009  0.0019
     32        1.0468  0.0018
     33        1.1797  0.002

#### Classification Report


In [4]:
from sklearn.metrics import classification_report

print(classification_report(y_test, predict))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        14
           1       1.00      0.12      0.22         8
           2       0.53      1.00      0.70         8

    accuracy                           0.77        30
   macro avg       0.84      0.71      0.64        30
weighted avg       0.88      0.77      0.71        30



### 5. Improving the results.

Note, the results of doing mini batch learning are very bad. This is because we only do one single epoch. An easy way to improve this is by running partial fit for each mini-batch several times, i.e, multiple epochs.


In [5]:
model_skorch = ONNHBPClassifier(
    input_units=4, output_units=3, hidden_units=50, n_hidden_layers=3
)
for _ in range(10):  # n_epochs
    for batch in gen_batches(X_train.shape[0], batch_size):
        x_batch = X_train[batch]
        y_batch = y_train[batch]
        partial_fit = model_skorch.partial_fit(x_batch, y_batch)
        # Shuffling the dataset
        permutation = torch.randperm(X_train.shape[0])
        X_train = X_train[permutation]
        y_train = y_train[permutation]

# Note how the results improved considerably
print("Accuracy after 10 epochs", partial_fit.score(X_test, y_test))

  epoch    train_loss     dur
-------  ------------  ------
      1        [36m1.1214[0m  0.0022
      2        1.1965  0.0021
      3        1.1261  0.0020
      4        [36m1.1113[0m  0.0023
      5        1.1471  0.0024
      6        1.1163  0.0019
      7        1.1306  0.0024
      8        [36m1.0858[0m  0.0020
      9        1.0950  0.0021
     10        1.1333  0.0019
     11        1.1442  0.0020
     12        1.0961  0.0017
     13        [36m1.0722[0m  0.0017
     14        1.0840  0.0015
     15        [36m1.0715[0m  0.0017
     16        1.1183  0.0017
     17        1.1100  0.0018
     18        [36m1.0677[0m  0.0020
     19        1.1053  0.0018
     20        1.0831  0.0016
     21        [36m1.0513[0m  0.0017
     22        1.0653  0.0017
     23        1.0928  0.0018
     24        [36m1.0325[0m  0.0017
     25        1.1849  0.0017
     26        1.0868  0.0017
     27        1.1084  0.0020
     28        1.0950  0.0016
     29        1.0841  0.001