In [1]:
import torch
import numpy as np
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
X = (X - np.mean(X))/np.std(X)
Y = iris.target

## Classification
Response: Species -> 3 classes

In [2]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.Net = torch.nn.Sequential(
            torch.nn.Linear(4, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 3)
        )
    def forward(self, x):
        return self.Net(x)

NN = Net()

dataset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.long))
DT = torch.utils.data.DataLoader(dataset, batch_size=20, shuffle=True)

optim = torch.optim.Adamax(NN.parameters())

for e in range(100):
    for x, y in DT:
        optim.zero_grad()
        pred = NN.forward(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        loss.backward()
        optim.step()
    if e%10 == 0:
        print(loss.item())


1.1210181713104248
1.0183498859405518
0.7089935541152954
0.39600253105163574
0.28598326444625854
0.2722209095954895
0.3285526633262634
0.22217264771461487
0.07098357379436493
0.10444505512714386


Make predictions:

In [3]:
pred = NN.forward(torch.tensor(X, dtype=torch.float32))

Probabilities for each species:

In [4]:
print(torch.nn.functional.softmax(pred,1))

tensor([[9.8932e-01, 1.0682e-02, 9.9086e-08],
        [9.8605e-01, 1.3953e-02, 1.2701e-07],
        [9.8959e-01, 1.0407e-02, 1.0176e-07],
        [9.8699e-01, 1.3009e-02, 1.7625e-07],
        [9.8945e-01, 1.0552e-02, 1.1009e-07],
        [9.8478e-01, 1.5224e-02, 2.8323e-07],
        [9.8788e-01, 1.2121e-02, 1.8922e-07],
        [9.8819e-01, 1.1807e-02, 1.3034e-07],
        [9.8760e-01, 1.2396e-02, 1.5807e-07],
        [9.8593e-01, 1.4066e-02, 1.2858e-07],
        [9.8832e-01, 1.1679e-02, 1.0682e-07],
        [9.8712e-01, 1.2879e-02, 1.9141e-07],
        [9.8763e-01, 1.2373e-02, 1.0152e-07],
        [9.9084e-01, 9.1580e-03, 7.0760e-08],
        [9.9147e-01, 8.5320e-03, 4.2687e-08],
        [9.8841e-01, 1.1588e-02, 1.5844e-07],
        [9.8932e-01, 1.0677e-02, 1.1469e-07],
        [9.8850e-01, 1.1495e-02, 1.2736e-07],
        [9.8265e-01, 1.7352e-02, 2.0907e-07],
        [9.8807e-01, 1.1930e-02, 1.7094e-07],
        [9.7910e-01, 2.0902e-02, 2.3527e-07],
        [9.8704e-01, 1.2963e-02, 2

## Regression
Response: Sepal.Length -> continuous

In [5]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.Net = torch.nn.Sequential(
            torch.nn.Linear(3, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1)
        )
    def forward(self, x):
        return self.Net(x)

NN = Net()

dataset = torch.utils.data.TensorDataset(torch.tensor(X[:,1:], dtype=torch.float32), torch.tensor(X[:,[0]], dtype=torch.float32))
DT = torch.utils.data.DataLoader(dataset, batch_size=20, shuffle=True)

optim = torch.optim.Adamax(NN.parameters())

for e in range(100):
    for x, y in DT:
        optim.zero_grad()
        pred = NN.forward(x)
        loss = torch.nn.functional.mse_loss(pred, y)
        loss.backward()
        optim.step()
    if e%10 == 0:
        print(loss.item())


1.2916386127471924
0.30858591198921204
0.0590483620762825
0.039184268563985825
0.039229974150657654
0.03206280618906021
0.035786859691143036
0.02563987299799919
0.03576971963047981
0.028240865096449852


Make predictions:

In [6]:
pred = NN.forward(torch.tensor(X[:,1:], dtype=torch.float32))

In [7]:
pred

tensor([[0.7713],
        [0.7427],
        [0.7393],
        [0.7609],
        [0.7803],
        [0.8094],
        [0.7481],
        [0.7758],
        [0.7377],
        [0.7773],
        [0.8001],
        [0.7891],
        [0.7590],
        [0.7191],
        [0.7851],
        [0.8028],
        [0.7647],
        [0.7571],
        [0.8163],
        [0.7948],
        [0.8024],
        [0.7715],
        [0.7370],
        [0.7485],
        [0.8291],
        [0.7693],
        [0.7565],
        [0.7821],
        [0.7625],
        [0.7792],
        [0.7743],
        [0.7446],
        [0.8353],
        [0.8146],
        [0.7609],
        [0.7260],
        [0.7606],
        [0.7946],
        [0.7294],
        [0.7758],
        [0.7463],
        [0.6783],
        [0.7393],
        [0.7357],
        [0.8236],
        [0.7263],
        [0.8198],
        [0.7526],
        [0.8001],
        [0.7575],
        [1.4096],
        [1.3335],
        [1.4601],
        [1.0845],
        [1.3266],
        [1