In [2]:
import torch
# Sanity-check the Apple Metal (MPS) backend: when it is available, allocate a
# one-element tensor on it and show it; otherwise report that it is missing.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


In [3]:
import numpy as np
from sklearn.datasets import make_classification
from torch import nn

In [4]:
from skorch import NeuralNetClassifier

# Synthetic classification data: 1000 samples x 20 features, 10 of them
# informative; random_state=0 makes the draw reproducible.
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=10,
    random_state=0,
)
# Cast to the dtypes fed to the torch module: float32 features, int64 labels.
X, y = X.astype(np.float32), y.astype(np.int64)

class MyModule(nn.Module):
    """Small MLP classifier that returns per-class probabilities.

    Architecture: Linear(input_dim -> num_units) -> nonlin -> Dropout(0.5)
    -> Linear(num_units -> num_units) -> nonlin -> Linear(num_units ->
    num_classes) -> Softmax over the last dimension.

    Args:
        num_units: Width of both hidden layers.
        nonlin: Activation module applied after each hidden layer. ``None``
            (the default) builds a fresh ``nn.ReLU()`` for this instance.
        input_dim: Number of input features. Defaults to 20, the feature
            count of the data generated in this notebook.
        num_classes: Number of output classes (defaults to 2).
    """

    def __init__(self, num_units=10, nonlin=None, input_dim=20, num_classes=2):
        super().__init__()
        # Create the default activation per instance rather than sharing a
        # single module object evaluated once at definition time.
        if nonlin is None:
            nonlin = nn.ReLU()

        self.dense0 = nn.Linear(input_dim, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, num_classes)
        # NOTE(review): the module emits probabilities, not log-probabilities;
        # skorch's NeuralNetClassifier documents handling this for its default
        # NLLLoss criterion — re-check if the criterion is changed.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        """Map a batch of features to class probabilities.

        ``X`` is presumably (batch, input_dim) float32 — TODO confirm for
        other callers. Extra keyword arguments are accepted and ignored.
        """
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X

In [5]:
# Wrap the PyTorch module class in a scikit-learn compatible estimator:
# 50 epochs at learning rate 0.1, with the training data reshuffled on
# every epoch via the training iterator.
net = NeuralNetClassifier(
    module=MyModule,
    max_epochs=50,
    lr=0.1,
    iterator_train__shuffle=True,
)

In [6]:
# Seed torch's global RNG right before fitting so weight initialisation,
# dropout masks and (presumably, via the DataLoader) the per-epoch shuffle
# are reproducible under Restart & Run All.
torch.manual_seed(0)
net.fit(X, y)
y_proba = net.predict_proba(X)

# Sanity checks: one row per sample, one column per class, rows sum to 1.
assert y_proba.shape == (len(X), len(np.unique(y)))
assert np.allclose(y_proba.sum(axis=1), 1.0, atol=1e-5)
print(y_proba, "is calculated probability.")

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7105[0m       [32m0.4900[0m        [35m0.6951[0m  0.0105
      2        [36m0.6940[0m       [32m0.5350[0m        [35m0.6838[0m  0.0077
      3        [36m0.6849[0m       [32m0.5750[0m        [35m0.6764[0m  0.0072
      4        [36m0.6697[0m       [32m0.6200[0m        [35m0.6714[0m  0.0074
      5        [36m0.6696[0m       [32m0.6300[0m        [35m0.6664[0m  0.0074
      6        [36m0.6469[0m       [32m0.6550[0m        [35m0.6583[0m  0.0077
      7        [36m0.6359[0m       0.6450        [35m0.6501[0m  0.0075
      8        [36m0.6258[0m       [32m0.6650[0m        [35m0.6410[0m  0.0076
      9        [36m0.6132[0m       0.6500        [35m0.6285[0m  0.0068
     10        0.6221       [32m0.6750[0m        [35m0.6216[0m  0.0080
     11        [36m0.6043[0m       [32m0.6900[0m        [35m0.6127[