In [70]:
import pickle

import numpy as np
import torch
import torch.nn.functional as F

from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from torch import nn, optim

from skorch.dataset import CVSplit
from skorch.callbacks import EpochScoring
from skorch.classifier import NeuralNetClassifier
from skorch.net import NeuralNet

In [27]:
# Synthetic 3-class classification problem: 1000 samples, 20 features
# (10 of them informative). Seeded so the data is identical on every run.
X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, n_classes=3,
                           random_state=123)

# Cast to the dtypes the torch model/loss are fed with downstream:
# float32 features and int64 class indices.
X = X.astype(np.float32)
y = y.astype(np.int64)

# Stratified 80/20 split. random_state pins the shuffle so that a
# Restart & Run All reproduces the same train/test partition (and therefore
# comparable metrics) instead of a fresh random split each time.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, shuffle=True, stratify=y, train_size=0.8, random_state=123)
print(f"Type X: {X_train.dtype}")
print(f"Type y: {y_train.dtype}")


Type X: float32
Type y: int64


In [59]:
class NetClassifier(nn.Module):
    """Small fully connected classifier that returns class probabilities.

    Architecture: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Linear -> softmax.

    The forward pass returns probabilities (softmax output), not logits or
    log-probabilities, so it must be paired with a training wrapper/loss that
    expects probabilities.

    Parameters
    ----------
    num_units : int, default=64
        Width of the first hidden layer.
    n_features : int, default=20
        Number of input features per sample.
    hidden_units : int, default=32
        Width of the second hidden layer.
    n_classes : int, default=3
        Number of output classes.
    dropout_p : float, default=0.2
        Dropout probability applied after the first hidden layer.

    All defaults reproduce the previously hard-coded 20 -> num_units -> 32 -> 3
    network, so existing callers that only pass ``num_units`` are unaffected.
    """

    def __init__(self, num_units=64, n_features=20, hidden_units=32,
                 n_classes=3, dropout_p=0.2):
        super().__init__()
        self.num_units = num_units

        # Layers are created in the same order (and under the same attribute
        # names) as before, so seeded initialisation and saved state dicts
        # stay compatible.
        self.linear = nn.Linear(n_features, self.num_units)
        self.linear2 = nn.Linear(self.num_units, hidden_units)
        self.linear3 = nn.Linear(hidden_units, n_classes)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        """Map a batch of feature rows to per-class probabilities.

        ``x`` is fed straight into the first ``nn.Linear``; softmax is taken
        over ``dim=1``, so a 2-D ``(batch, n_features)`` input yields a
        ``(batch, n_classes)`` output whose rows sum to 1.
        """
        x = F.relu(self.linear(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        x = F.softmax(self.linear3(x), dim=1)
        return x

    
# Per-epoch accuracy on the training data; shows up in the training log
# under the metric function's name ("accuracy_score").
acc_train = EpochScoring(scoring=accuracy_score, lower_is_better=False, on_train=True)
# Validation counterpart. It is NOT registered in `callbacks` below: the
# training log already contains a `valid_acc` column without it.
acc_valid = EpochScoring(scoring=accuracy_score, lower_is_better=False, on_train=False)

# Seed torch so weight initialisation and dropout masks repeat between runs.
torch.manual_seed(123)

net = NeuralNetClassifier(
    module=NetClassifier,
    # The module returns softmax probabilities; NeuralNetClassifier is used
    # with NLLLoss on top of that output.
    criterion=nn.NLLLoss,
    optimizer=optim.Adam,
    lr=1e-3,
    max_epochs=25,
    # Use the GPU when one is visible, otherwise fall back to the CPU so the
    # notebook also runs on machines without CUDA.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # Hold out a stratified 15% of the training data for validation;
    # random_state pins which rows end up in that hold-out set.
    train_split=CVSplit(cv=0.15, stratified=True, random_state=123),
    callbacks=[acc_train],
)

# Train on the training split, then produce hard labels and class
# probabilities for the held-out test split.
net.fit(X_train, y_train)
y_class = net.predict(X_test)
y_prob = net.predict_proba(X_test)


  epoch    accuracy_score    train_loss    valid_acc    valid_loss     dur
-------  ----------------  ------------  -----------  ------------  ------
      1            [36m0.3103[0m        [32m1.1075[0m       [35m0.3833[0m        [31m1.0779[0m  0.0411
      2            [36m0.4809[0m        [32m1.0574[0m       [35m0.5333[0m        [31m1.0481[0m  0.0361
      3            [36m0.5588[0m        [32m1.0288[0m       [35m0.5667[0m        [31m1.0235[0m  0.0356
      4            0.5529        [32m1.0006[0m       [35m0.6000[0m        [31m0.9996[0m  0.0351
      5            [36m0.6132[0m        [32m0.9743[0m       0.5667        [31m0.9746[0m  0.0341
      6            0.6088        [32m0.9456[0m       0.5583        [31m0.9480[0m  0.0351
      7            [36m0.6191[0m        [32m0.9174[0m       0.5417        [31m0.9219[0m  0.0321
      8            [36m0.6412[0m        [32m0.8785[0m       0.5500        [31m0.8956[0m  0.0361
      9         

In [36]:
# Display the predicted class labels currently bound to `y_class`
# (assigned from a `.predict(X_test)` call in another cell).
y_class

array([0, 0, 1, 2, 1, 1, 0, 2, 0, 2, 0, 0, 2, 0, 1, 2, 0, 1, 1, 1, 2, 0,
       2, 2, 1, 1, 0, 2, 2, 2, 0, 1, 2, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1,
       2, 2, 1, 2, 0, 0, 1, 2, 1, 0, 2, 0, 2, 1, 0, 1, 1, 2, 2, 0, 2, 1,
       0, 1, 0, 0, 0, 1, 0, 2, 1, 0, 0, 0, 1, 0, 2, 1, 2, 0, 0, 0, 1, 0,
       1, 2, 2, 0, 1, 0, 2, 0, 0, 1, 0, 2, 2, 2, 1, 0, 0, 2, 1, 1, 0, 1,
       0, 0, 1, 2, 1, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 2,
       0, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1,
       2, 2, 2, 2, 2, 1, 0, 2, 2, 2, 0, 2, 0, 2, 0, 0, 2, 1, 1, 1, 1, 0,
       0, 1, 0, 1, 1, 2, 1, 2, 1, 0, 2, 2, 0, 0, 0, 0, 1, 2, 0, 1, 2, 2,
       2, 1], dtype=int64)

In [37]:
# Display the per-class probabilities currently bound to `y_prob`
# (assigned from a `.predict_proba(X_test)` call in another cell);
# one row per test sample, one column per class.
y_prob

array([[8.91190708e-01, 5.14783897e-02, 5.73309213e-02],
       [7.14790165e-01, 1.87224895e-01, 9.79849398e-02],
       [1.29795969e-01, 5.23964942e-01, 3.46239090e-01],
       [2.42265776e-01, 3.26228857e-01, 4.31505442e-01],
       [1.89428613e-01, 7.90573239e-01, 1.99981071e-02],
       [2.36641631e-01, 5.70246994e-01, 1.93111360e-01],
       [7.59905636e-01, 2.05117185e-02, 2.19582722e-01],
       [1.92739591e-01, 2.07736343e-03, 8.05182993e-01],
       [7.02628613e-01, 1.33170620e-01, 1.64200738e-01],
       [3.58929396e-01, 3.43587361e-02, 6.06711864e-01],
       [6.05406523e-01, 7.82932267e-02, 3.16300303e-01],
       [6.38577461e-01, 1.32435396e-01, 2.28987187e-01],
       [3.38282287e-02, 2.22261343e-03, 9.63949084e-01],
       [8.48609984e-01, 9.11165550e-02, 6.02733940e-02],
       [5.54128364e-02, 8.85331213e-01, 5.92559539e-02],
       [3.87716182e-02, 1.19336978e-01, 8.41891468e-01],
       [5.08691430e-01, 4.55509275e-02, 4.45757687e-01],
       [5.54309860e-02, 7.53975

## Sklearn Pipeline

In [50]:
# The skorch net is the one and only step of this scikit-learn Pipeline;
# the step name 'net' is what parameter names get prefixed with.
pipe = Pipeline(steps=[('net', net)])

# Fit through the pipeline, then predict labels and probabilities for the
# held-out split (this rebinds y_class / y_prob).
pipe.fit(X_train, y_train)
y_class, y_prob = pipe.predict(X_test), pipe.predict_proba(X_test)


Re-initializing module because the following parameters were re-set: num_units.
Re-initializing optimizer because the following parameters were re-set: lr.
  epoch    accuracy_score    train_loss    valid_acc    valid_loss     dur
-------  ----------------  ------------  -----------  ------------  ------
      1            [36m0.3794[0m        [32m1.0921[0m       [35m0.5000[0m        [31m1.0707[0m  0.0582
      2            [36m0.4559[0m        [32m1.0609[0m       [35m0.5750[0m        [31m1.0428[0m  0.0401
      3            [36m0.5235[0m        [32m1.0251[0m       [35m0.5917[0m        [31m1.0151[0m  0.0356
      4            [36m0.5529[0m        [32m0.9966[0m       [35m0.6167[0m        [31m0.9860[0m  0.0343
      5            [36m0.5574[0m        [32m0.9623[0m       [35m0.6500[0m        [31m0.9553[0m  0.0396
      6            [36m0.5882[0m        [32m0.9358[0m       0.6417        [31m0.9238[0m  0.0346
      7            [36m0.5926[0m   

## GridSearchCV

In [51]:
# Show the parameter-name prefixes the net exposes. A prefix such as
# `module` is what appears in double-underscore parameter names like
# 'net__module__num_units' used for the grid search.
net.prefixes_

['module',
 'iterator_train',
 'iterator_valid',
 'optimizer',
 'criterion',
 'callbacks',
 'dataset']

In [62]:
# Single-step pipeline around the skorch net; 'net' is the step name that
# prefixes every tunable parameter below.
pipe = Pipeline(steps=[('net', net)])

# Candidate widths for the module's first hidden layer
# (step name -> 'module' prefix -> constructor argument).
grid = {'net__module__num_units': [50, 60]}

# 3-fold search over the grid, ranked by accuracy, with per-fit logging.
gs = GridSearchCV(pipe, grid, scoring='accuracy', cv=3, verbose=2)

# Run the search on the training split and predict the held-out split.
gs.fit(X_train, y_train)
predictions = gs.predict(X_test)


Fitting 3 folds for each of 2 candidates, totalling 6 fits
[CV] net__module__num_units=50 .......................................
Re-initializing module because the following parameters were re-set: num_units.
Re-initializing optimizer.
Re-initializing module because the following parameters were re-set: num_units.
Re-initializing optimizer.
  epoch    accuracy_score    train_loss    valid_acc    valid_loss     dur
-------  ----------------  ------------  -----------  ------------  ------
      1            [36m0.3451[0m        [32m1.1052[0m       [35m0.3125[0m        [31m1.0943[0m  0.0317
      2            [36m0.3982[0m        [32m1.0775[0m       [35m0.4250[0m        [31m1.0732[0m  0.0251
      3            [36m0.4690[0m        [32m1.0531[0m       [35m0.4625[0m        [31m1.0542[0m  0.0221
      4            [36m0.5221[0m        [32m1.0383[0m       [35m0.5125[0m        [31m1.0360[0m  0.0256
      5            0.5177        [32m1.0239[0m       0.5125 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.



      6            [36m0.5310[0m        [32m1.0083[0m       [35m0.5250[0m        [31m1.0022[0m  0.0276
      7            [36m0.5686[0m        [32m0.9847[0m       [35m0.5500[0m        [31m0.9865[0m  0.0231
      8            0.5597        [32m0.9761[0m       [35m0.5875[0m        [31m0.9707[0m  0.0251
      9            [36m0.5885[0m        [32m0.9507[0m       [35m0.6000[0m        [31m0.9563[0m  0.0261
     10            [36m0.6195[0m        [32m0.9243[0m       [35m0.6250[0m        [31m0.9432[0m  0.0251
     11            [36m0.6283[0m        [32m0.9185[0m       0.6250        [31m0.9317[0m  0.0231
     12            [36m0.6438[0m        [32m0.8856[0m       [35m0.6375[0m        [31m0.9219[0m  0.0241
     13            0.6173        [32m0.8762[0m       0.6375        [31m0.9127[0m  0.0231
     14            [36m0.6482[0m        [32m0.8548[0m       0.6375        [31m0.9025[0m  0.0221
     15            0.6350        [32m0.8450

[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.8s remaining:    0.0s


      6            [36m0.5629[0m        [32m0.9761[0m       0.5625        [31m0.9683[0m  0.0241
      7            [36m0.6115[0m        [32m0.9422[0m       0.5750        [31m0.9507[0m  0.0226
      8            0.5982        [32m0.9270[0m       [35m0.6000[0m        [31m0.9330[0m  0.0261
      9            0.6093        [32m0.9099[0m       [35m0.6250[0m        [31m0.9161[0m  0.0226
     10            [36m0.6358[0m        [32m0.8868[0m       [35m0.6375[0m        [31m0.8993[0m  0.0216
     11            0.6093        [32m0.8696[0m       0.6375        [31m0.8828[0m  0.0216
     12            [36m0.6645[0m        [32m0.8553[0m       0.6375        [31m0.8665[0m  0.0256
     13            0.6313        [32m0.8365[0m       0.6375        [31m0.8514[0m  0.0281
     14            0.6534        [32m0.8102[0m       0.6375        [31m0.8369[0m  0.0246
     15            [36m0.6667[0m        [32m0.7860[0m       [35m0.6500[0m        [31m0.8232[

[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    4.8s finished



      5            [36m0.5324[0m        [32m0.9688[0m       [35m0.5417[0m        [31m1.0006[0m  0.0321
      6            [36m0.5765[0m        [32m0.9388[0m       0.5417        [31m0.9783[0m  0.0371
      7            [36m0.5809[0m        [32m0.9148[0m       [35m0.5500[0m        [31m0.9555[0m  0.0321
      8            [36m0.5971[0m        [32m0.8950[0m       [35m0.5750[0m        [31m0.9303[0m  0.0346
      9            [36m0.6147[0m        [32m0.8672[0m       [35m0.5917[0m        [31m0.9027[0m  0.0321
     10            0.6088        [32m0.8532[0m       0.5750        [31m0.8759[0m  0.0316
     11            [36m0.6471[0m        [32m0.8238[0m       0.5917        [31m0.8511[0m  0.0321
     12            0.6456        [32m0.8002[0m       [35m0.6167[0m        [31m0.8284[0m  0.0366
     13            [36m0.6529[0m        [32m0.7944[0m       [35m0.6500[0m        [31m0.8067[0m  0.0321
     14            [36m0.6794[0m        [3

In [68]:
# Confusion matrix of the grid-search predictions against the true
# test labels (true labels passed first, predictions second).
confusion_matrix(y_test, predictions)

array([[47,  6, 18],
       [ 6, 51,  6],
       [14,  9, 43]], dtype=int64)

## Saving and loading the model

In [71]:
# Serialize the whole GridSearchCV object `gs` with pickle and write it to
# 'model.pkl' (relative path, so it lands in the current working directory).
with open('model.pkl', 'wb') as f:
    pickle.dump(gs, f)


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [72]:
# Read the pickled object back from 'model.pkl' and bind it to `model`.
# NOTE(review): unpickling can execute arbitrary code, so only load pickle
# files that you created yourself or otherwise trust.
# NOTE(review): pickle stores classes by reference, so the classes used by
# the saved object (presumably including NetClassifier) must be defined or
# importable in the loading session - confirm when loading elsewhere.
with open('model.pkl', 'rb') as f:
    model = pickle.load(f)
