In [24]:
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [25]:
iris = load_iris()

In [26]:
X, Y = iris.data, iris.target

In [27]:
X.shape

(150, 4)

In [28]:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.2, random_state = 200)

In [29]:
Y_train

array([0, 1, 1, 0, 1, 0, 0, 0, 2, 1, 1, 0, 0, 1, 2, 2, 2, 0, 2, 1, 2, 1,
       1, 1, 2, 0, 1, 1, 1, 2, 0, 0, 2, 0, 0, 2, 1, 0, 0, 2, 0, 0, 2, 2,
       1, 1, 2, 0, 0, 2, 1, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 1, 0, 2, 1, 2,
       1, 2, 1, 2, 2, 1, 1, 2, 1, 0, 2, 2, 1, 1, 0, 0, 2, 0, 1, 2, 1, 0,
       0, 0, 2, 2, 0, 1, 2, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0, 0, 1, 2, 2, 2,
       2, 0, 1, 1, 1, 0, 1, 0, 2, 0])

In [30]:
scaler = StandardScaler()

In [31]:
X_train = scaler.fit_transform(X_train)
X_test = scaler.fit_transform(X_test)

In [32]:
type(X_train)

numpy.ndarray

In [33]:
X_train_tensor = torch.tensor(X_train, dtype = torch.float32)
X_test_tensor = torch.tensor(X_test, dtype = torch.float32)
Y_train_tensor = torch.tensor(Y_train, dtype = torch.long)
Y_test_tensor = torch.tensor(Y_test, dtype = torch.long)

In [34]:
type(X_train_tensor)

torch.Tensor

In [35]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(4, 20)
        self.layer2 = nn.Linear(20, 3)
        self.output = nn.Softmax()

    def forward(self, X):
      X = self.layer1(X)
      X = self.layer2(X)
      X = self.output(X)
      return X

In [36]:
model = Model()

In [37]:
loss = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr = 0.01)

In [38]:
epochs = 1000

for epoch in range(epochs):
    optim.zero_grad()
    forward_output = model(X_train_tensor)
    cost = loss(forward_output, Y_train_tensor)
    cost.backward()
    optim.step()

    print(f'Epoch: {epoch}, Loss: {cost}')

  return self._call_impl(*args, **kwargs)


Epoch: 0, Loss: 1.0802052021026611
Epoch: 1, Loss: 1.0392463207244873
Epoch: 2, Loss: 0.9996952414512634
Epoch: 3, Loss: 0.9626389741897583
Epoch: 4, Loss: 0.9292826056480408
Epoch: 5, Loss: 0.9004572629928589
Epoch: 6, Loss: 0.8764475584030151
Epoch: 7, Loss: 0.857020378112793
Epoch: 8, Loss: 0.8415334224700928
Epoch: 9, Loss: 0.8291209936141968
Epoch: 10, Loss: 0.8189296126365662
Epoch: 11, Loss: 0.8102607131004333
Epoch: 12, Loss: 0.8025936484336853
Epoch: 13, Loss: 0.795559287071228
Epoch: 14, Loss: 0.7889086604118347
Epoch: 15, Loss: 0.7824851274490356
Epoch: 16, Loss: 0.7762024998664856
Epoch: 17, Loss: 0.7700262069702148
Epoch: 18, Loss: 0.7639579176902771
Epoch: 19, Loss: 0.7580220699310303
Epoch: 20, Loss: 0.7522551417350769
Epoch: 21, Loss: 0.7466951012611389
Epoch: 22, Loss: 0.7413745522499084
Epoch: 23, Loss: 0.7363137006759644
Epoch: 24, Loss: 0.731516420841217
Epoch: 25, Loss: 0.7269681096076965
Epoch: 26, Loss: 0.7226382493972778
Epoch: 27, Loss: 0.7184856534004211
Epoch

In [39]:
model_output = model(X_test_tensor)
print(model_output)

tensor([[2.9377e-04, 9.9887e-01, 8.4108e-04],
        [0.0000e+00, 8.5948e-30, 1.0000e+00],
        [9.9943e-01, 5.6781e-04, 0.0000e+00],
        [1.0000e+00, 1.2679e-06, 0.0000e+00],
        [4.6640e-10, 1.0000e+00, 2.5754e-12],
        [1.4600e-36, 7.3745e-25, 1.0000e+00],
        [1.8483e-07, 1.0000e+00, 2.3803e-20],
        [5.1325e-06, 9.9999e-01, 8.3483e-29],
        [7.5909e-08, 1.0000e+00, 2.0919e-17],
        [1.8411e-38, 3.7944e-29, 1.0000e+00],
        [8.1836e-27, 4.2204e-17, 1.0000e+00],
        [9.9995e-01, 5.2043e-05, 0.0000e+00],
        [9.9833e-01, 1.6665e-03, 0.0000e+00],
        [1.0000e+00, 6.5308e-07, 0.0000e+00],
        [1.4220e-31, 9.7897e-20, 1.0000e+00],
        [9.9999e-01, 6.4175e-06, 0.0000e+00],
        [4.5929e-06, 1.0000e+00, 8.2846e-24],
        [2.5219e-02, 9.7478e-01, 0.0000e+00],
        [7.7454e-08, 9.9989e-01, 1.1385e-04],
        [1.0880e-20, 2.2226e-12, 1.0000e+00],
        [1.0000e+00, 6.0544e-07, 0.0000e+00],
        [5.8251e-08, 1.0000e+00, 6

  return self._call_impl(*args, **kwargs)
