In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [3]:
iris = load_iris()
X, y = iris['data'], iris['target']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=1.0 / 3.0, random_state=1
)

# Standardize with ONE scalar mean/std computed over all training features
# (not per column). Presumably this is the preprocessing the saved weights
# were trained with -- confirm before switching to per-feature scaling.
X_train_norm = torch.from_numpy(
    (X_train - np.mean(X_train)) / np.std(X_train)
).float()

# Deviation from the book / GitHub code: class-index targets for the loss
# must be torch.long, and the integer dtype of iris['target'] is not
# guaranteed to be int64 everywhere, so cast explicitly.
y_train = torch.from_numpy(y_train).to(dtype=torch.long)

train_ds = TensorDataset(X_train_norm, y_train)

# Seed torch's global RNG (drawn from by the shuffling DataLoader).
torch.manual_seed(1)
BATCH_SIZE = 2
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

In [4]:
# Constants
# Training hyperparameters (training-only; this notebook just loads and
# evaluates saved weights).
LEARNING_RATE = 0.001
NUMBER_OF_EPOCHS = 100
# Architecture sizes -- must match the saved checkpoint for
# load_state_dict to succeed. Derived from the data where possible
# instead of hard-coding.
INPUT_SIZE = X_train_norm.shape[1]       # features per sample
HIDDEN_SIZE = 16
OUTPUT_SIZE = len(iris['target_names'])  # one output per class (3 for iris)

In [5]:
class Model(nn.Module):
    """Two-layer feed-forward classifier: Linear -> sigmoid -> Linear -> softmax.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of classes; ``forward`` returns one probability
            per class.

    The attribute names ``layer1`` / ``layer2`` determine the ``state_dict``
    keys, so they must not be renamed or saved checkpoints will not load.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for ``x``.

        ``x`` may be a batch ``(batch, input_size)`` or a single unbatched
        sample ``(input_size,)``. Softmax is taken over the last (class)
        dimension, which is identical to ``dim=1`` for 2-D batches but also
        works for 1-D input.

        NOTE(review): this returns probabilities, not logits. If training
        used ``nn.CrossEntropyLoss`` (which applies log-softmax internally),
        softmax was effectively applied twice -- confirm against the
        training code.
        """
        # Activations are stateless, so use the functional forms rather than
        # constructing fresh nn.Sigmoid / nn.Softmax modules on every call.
        hidden = torch.sigmoid(self.layer1(x))
        logits = self.layer2(hidden)
        return torch.softmax(logits, dim=-1)

In [6]:
path = './model_state/iris_classifier_state.pt'
# Build the architecture and switch it to inference mode (nn.Module.eval()
# returns the module itself). Eval mode is a no-op for the current layers
# but is the correct mode for evaluation.
model = Model(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, output_size=OUTPUT_SIZE).eval()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine. torch.load unpickles the file -- only load trusted
# checkpoints.
model.load_state_dict(torch.load(path, map_location='cpu'))

<All keys matched successfully>

In [7]:
# Standardize the test set with the TRAINING statistics so that test
# features are on exactly the same scale the model saw during training
# (using the test set's own mean/std mis-scales it and leaks test info).
X_test_norm = (X_test - np.mean(X_train)) / np.std(X_train)
X_test_norm = torch.from_numpy(X_test_norm).float()
# torch.as_tensor accepts an ndarray or an existing tensor, so re-running
# this cell does not fail once y_test has already been converted.
y_test = torch.as_tensor(y_test)

In [8]:
# Evaluate the loaded model on the held-out test set.
# Ensure inference mode and skip autograd bookkeeping: no gradients are
# needed to compute predictions.
model.eval()
with torch.no_grad():
    softmax_multiclass_probabilities = model(X_test_norm)
# Predicted class = index of the highest probability in each row.
predicted_labels = torch.argmax(softmax_multiclass_probabilities, dim=1)
correct = (predicted_labels == y_test).float()
accuracy = correct.mean()
print(f'Model Accuracy on Test Set: {accuracy*100:.2f}%')

Model Accuracy on Test Set: 98.00%
