In [1]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch import optim
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import tqdm
from copy import deepcopy

In [2]:
# Load the Iris data set and turn it into float32 tensors.
# The integer class labels are reshaped into a single column and one-hot
# encoded, giving one indicator column per class.
data = load_iris()
ohe = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
labels_onehot = ohe.fit_transform(data['target'].reshape(-1, 1))
X = torch.tensor(data['data'], dtype=torch.float32)
y = torch.tensor(labels_onehot, dtype=torch.float32)

In [3]:
# Keep 70% of the samples for training and hold out the rest for evaluation.
# A fixed random_state makes the shuffled split identical on every
# Restart & Run All, so the logged metrics can be reproduced.
xtr, xts, ytr, yts = train_test_split(X, y, train_size=0.7, shuffle=True, random_state=42)

In [4]:
class Multiclass(nn.Module):
    """Feed-forward classifier: Linear(4 -> 8), ReLU, Linear(8 -> 3).

    ``forward`` returns the output of the last linear layer directly; no
    softmax is applied inside the module.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order they are applied in forward().
        self.hidden = nn.Linear(in_features=4, out_features=8)
        self.act = nn.ReLU()
        self.output = nn.Linear(in_features=8, out_features=3)

    def forward(self, x):
        """Map a batch of 4-feature inputs to 3 unnormalised class scores."""
        hidden_activation = self.act(self.hidden(x))
        return self.output(hidden_activation)

# Seed PyTorch's global RNG before building the model so the layer weights are
# initialised identically on every run.
torch.manual_seed(42)
model = Multiclass()
# Multiclass.forward returns the raw output of its last linear layer, which is
# what CrossEntropyLoss is applied to.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
def train(model, xtr, ytr, xts, yts, n_epochs = 100, batch_size = 5):
    """Train ``model`` in mini-batches and evaluate it on a held-out split each epoch.

    The loss function and the optimizer are not parameters: this function reads
    the module-level ``loss_fn`` and ``optimizer`` objects, so both must exist
    before it is called.

    Args:
        model: Network to train; its parameters are updated in place.
        xtr: Training inputs, sliced along the first dimension into batches.
        ytr: Training targets with one row per sample; the class index is
            recovered with ``argmax`` along dimension 1.
        xts: Held-out inputs, evaluated in a single forward pass per epoch.
        yts: Held-out targets in the same layout as ``ytr``.
        n_epochs: Number of passes over the training data.
        batch_size: Number of samples per optimisation step.

    Returns:
        A dict with the per-epoch lists ``train_loss``, ``train_acc``,
        ``test_loss`` and ``test_acc`` plus the scalar ``best_acc``. When at
        least one epoch ran, ``model`` is left holding the weights of the epoch
        with the highest held-out accuracy.
    """
    # Ceiling division: a trailing partial batch is still trained on instead of
    # being silently dropped when len(xtr) is not a multiple of batch_size.
    batches_per_epoch = (len(xtr) + batch_size - 1) // batch_size

    best_acc = - np.inf
    best_weights = None
    train_loss_hist = []
    train_acc_hist = []
    test_loss_hist = []
    test_acc_hist = []

    for epoch in range(n_epochs):
        epoch_loss = []
        epoch_acc = []
        # set model in training mode and run through each batch
        model.train()
        with tqdm.trange(batches_per_epoch, unit="batch", mininterval=0) as bar:
            bar.set_description(f"Epoch {epoch}")
            for i in bar:
                # take a batch (slicing clips the last batch to the data length)
                start = i * batch_size
                X_batch = xtr[start:start+batch_size]
                y_batch = ytr[start:start+batch_size]
                # forward pass
                y_pred = model(X_batch)
                loss = loss_fn(y_pred, y_batch)
                # backward pass
                optimizer.zero_grad()
                loss.backward()
                # update weights
                optimizer.step()
                # compute and store metrics
                acc = (torch.argmax(y_pred, 1) == torch.argmax(y_batch, 1)).float().mean()
                epoch_loss.append(float(loss))
                epoch_acc.append(float(acc))
                bar.set_postfix(
                    loss=float(loss),
                    acc=float(acc)
                )
        # set model in evaluation mode and run through the test set;
        # no_grad avoids building an autograd graph that is never used
        model.eval()
        with torch.no_grad():
            y_pred = model(xts)
            ce = float(loss_fn(y_pred, yts))
            acc = float((torch.argmax(y_pred, 1) == torch.argmax(yts, 1)).float().mean())
        train_loss_hist.append(np.mean(epoch_loss))
        train_acc_hist.append(np.mean(epoch_acc))
        test_loss_hist.append(ce)
        test_acc_hist.append(acc)
        if acc > best_acc:
            best_acc = acc
            # deepcopy: state_dict() returns references to the live tensors
            best_weights = deepcopy(model.state_dict())
        print(f"Epoch {epoch} validation: Cross-entropy={ce:.2f}, Accuracy={acc*100:.1f}%")

    # Restore the best checkpoint; previously it was captured but never used.
    if best_weights is not None:
        model.load_state_dict(best_weights)

    return {
        "train_loss": train_loss_hist,
        "train_acc": train_acc_hist,
        "test_loss": test_loss_hist,
        "test_acc": test_acc_hist,
        "best_acc": best_acc,
    }




In [6]:
# Run the training loop with the default epoch count and batch size.
# The trailing semicolon stops the notebook from echoing a return value.
train(model=model, xtr=xtr, ytr=ytr, xts=xts, yts=yts);

Epoch 0: 100%|██████████| 21/21 [00:00<00:00, 155.94batch/s, acc=0.2, loss=1.16]


Epoch 0 validation: Cross-entropy=1.14, Accuracy=24.4%


Epoch 1: 100%|██████████| 21/21 [00:00<00:00, 226.84batch/s, acc=0.2, loss=1.15]


Epoch 1 validation: Cross-entropy=1.13, Accuracy=22.2%


Epoch 2: 100%|██████████| 21/21 [00:00<00:00, 233.10batch/s, acc=0, loss=1.14]


Epoch 2 validation: Cross-entropy=1.13, Accuracy=8.9%


Epoch 3: 100%|██████████| 21/21 [00:00<00:00, 233.14batch/s, acc=0, loss=1.13]


Epoch 3 validation: Cross-entropy=1.12, Accuracy=0.0%


Epoch 4: 100%|██████████| 21/21 [00:00<00:00, 231.02batch/s, acc=0, loss=1.13]


Epoch 4 validation: Cross-entropy=1.12, Accuracy=2.2%


Epoch 5: 100%|██████████| 21/21 [00:00<00:00, 250.59batch/s, acc=0.4, loss=1.12]


Epoch 5 validation: Cross-entropy=1.11, Accuracy=40.0%


Epoch 6: 100%|██████████| 21/21 [00:00<00:00, 201.36batch/s, acc=0.4, loss=1.12]


Epoch 6 validation: Cross-entropy=1.11, Accuracy=40.0%


Epoch 7: 100%|██████████| 21/21 [00:00<00:00, 237.31batch/s, acc=0.4, loss=1.11]


Epoch 7 validation: Cross-entropy=1.10, Accuracy=40.0%


Epoch 8: 100%|██████████| 21/21 [00:00<00:00, 243.97batch/s, acc=0.4, loss=1.1]


Epoch 8 validation: Cross-entropy=1.10, Accuracy=40.0%


Epoch 9: 100%|██████████| 21/21 [00:00<00:00, 101.65batch/s, acc=0.4, loss=1.09]


Epoch 9 validation: Cross-entropy=1.09, Accuracy=40.0%


Epoch 10: 100%|██████████| 21/21 [00:00<00:00, 220.35batch/s, acc=0.4, loss=1.08]


Epoch 10 validation: Cross-entropy=1.08, Accuracy=40.0%


Epoch 11: 100%|██████████| 21/21 [00:00<00:00, 209.06batch/s, acc=0.4, loss=1.07]


Epoch 11 validation: Cross-entropy=1.07, Accuracy=40.0%


Epoch 12: 100%|██████████| 21/21 [00:00<00:00, 217.51batch/s, acc=0.4, loss=1.06]


Epoch 12 validation: Cross-entropy=1.06, Accuracy=40.0%


Epoch 13: 100%|██████████| 21/21 [00:00<00:00, 209.54batch/s, acc=0.4, loss=1.05]


Epoch 13 validation: Cross-entropy=1.05, Accuracy=40.0%


Epoch 14: 100%|██████████| 21/21 [00:00<00:00, 58.12batch/s, acc=0.4, loss=1.04]


Epoch 14 validation: Cross-entropy=1.04, Accuracy=40.0%


Epoch 15: 100%|██████████| 21/21 [00:00<00:00, 119.52batch/s, acc=0.4, loss=1.03]


Epoch 15 validation: Cross-entropy=1.03, Accuracy=40.0%


Epoch 16: 100%|██████████| 21/21 [00:00<00:00, 135.78batch/s, acc=0.4, loss=1.01]


Epoch 16 validation: Cross-entropy=1.02, Accuracy=40.0%


Epoch 17: 100%|██████████| 21/21 [00:00<00:00, 121.37batch/s, acc=0.4, loss=1]


Epoch 17 validation: Cross-entropy=1.01, Accuracy=42.2%


Epoch 18: 100%|██████████| 21/21 [00:00<00:00, 99.36batch/s, acc=0.4, loss=0.988] 


Epoch 18 validation: Cross-entropy=0.99, Accuracy=42.2%


Epoch 19: 100%|██████████| 21/21 [00:00<00:00, 26.97batch/s, acc=0.4, loss=0.976]


Epoch 19 validation: Cross-entropy=0.98, Accuracy=48.9%


Epoch 20: 100%|██████████| 21/21 [00:00<00:00, 223.47batch/s, acc=0.6, loss=0.963]


Epoch 20 validation: Cross-entropy=0.97, Accuracy=57.8%


Epoch 21: 100%|██████████| 21/21 [00:00<00:00, 103.01batch/s, acc=0.6, loss=0.949]


Epoch 21 validation: Cross-entropy=0.95, Accuracy=62.2%


Epoch 22: 100%|██████████| 21/21 [00:00<00:00, 62.88batch/s, acc=0.6, loss=0.936]


Epoch 22 validation: Cross-entropy=0.94, Accuracy=64.4%


Epoch 23: 100%|██████████| 21/21 [00:00<00:00, 78.40batch/s, acc=0.6, loss=0.922]


Epoch 23 validation: Cross-entropy=0.93, Accuracy=64.4%


Epoch 24: 100%|██████████| 21/21 [00:00<00:00, 120.86batch/s, acc=0.6, loss=0.907]


Epoch 24 validation: Cross-entropy=0.91, Accuracy=64.4%


Epoch 25: 100%|██████████| 21/21 [00:00<00:00, 54.68batch/s, acc=0.6, loss=0.891] 


Epoch 25 validation: Cross-entropy=0.89, Accuracy=64.4%


Epoch 26: 100%|██████████| 21/21 [00:00<00:00, 47.09batch/s, acc=0.6, loss=0.875] 


Epoch 26 validation: Cross-entropy=0.88, Accuracy=64.4%


Epoch 27: 100%|██████████| 21/21 [00:00<00:00, 58.15batch/s, acc=0.6, loss=0.861]


Epoch 27 validation: Cross-entropy=0.86, Accuracy=64.4%


Epoch 28: 100%|██████████| 21/21 [00:00<00:00, 110.56batch/s, acc=0.6, loss=0.847]


Epoch 28 validation: Cross-entropy=0.85, Accuracy=64.4%


Epoch 29: 100%|██████████| 21/21 [00:00<00:00, 110.51batch/s, acc=0.6, loss=0.835]


Epoch 29 validation: Cross-entropy=0.84, Accuracy=64.4%


Epoch 30: 100%|██████████| 21/21 [00:00<00:00, 58.65batch/s, acc=0.6, loss=0.823]


Epoch 30 validation: Cross-entropy=0.83, Accuracy=64.4%


Epoch 31: 100%|██████████| 21/21 [00:00<00:00, 110.84batch/s, acc=0.6, loss=0.813]


Epoch 31 validation: Cross-entropy=0.82, Accuracy=64.4%


Epoch 32: 100%|██████████| 21/21 [00:00<00:00, 104.05batch/s, acc=0.6, loss=0.803]


Epoch 32 validation: Cross-entropy=0.81, Accuracy=64.4%


Epoch 33: 100%|██████████| 21/21 [00:00<00:00, 166.49batch/s, acc=0.6, loss=0.794]


Epoch 33 validation: Cross-entropy=0.80, Accuracy=64.4%


Epoch 34: 100%|██████████| 21/21 [00:00<00:00, 135.06batch/s, acc=0.6, loss=0.785]


Epoch 34 validation: Cross-entropy=0.79, Accuracy=64.4%


Epoch 35: 100%|██████████| 21/21 [00:00<00:00, 115.75batch/s, acc=0.6, loss=0.777]


Epoch 35 validation: Cross-entropy=0.78, Accuracy=64.4%


Epoch 36: 100%|██████████| 21/21 [00:00<00:00, 93.17batch/s, acc=0.6, loss=0.77]


Epoch 36 validation: Cross-entropy=0.77, Accuracy=64.4%


Epoch 37: 100%|██████████| 21/21 [00:00<00:00, 115.20batch/s, acc=0.6, loss=0.763]


Epoch 37 validation: Cross-entropy=0.76, Accuracy=64.4%


Epoch 38: 100%|██████████| 21/21 [00:00<00:00, 147.03batch/s, acc=0.6, loss=0.756]


Epoch 38 validation: Cross-entropy=0.76, Accuracy=64.4%


Epoch 39: 100%|██████████| 21/21 [00:00<00:00, 97.33batch/s, acc=0.6, loss=0.749] 


Epoch 39 validation: Cross-entropy=0.75, Accuracy=64.4%


Epoch 40: 100%|██████████| 21/21 [00:00<00:00, 162.75batch/s, acc=0.6, loss=0.743]


Epoch 40 validation: Cross-entropy=0.74, Accuracy=64.4%


Epoch 41: 100%|██████████| 21/21 [00:00<00:00, 29.71batch/s, acc=0.6, loss=0.736] 


Epoch 41 validation: Cross-entropy=0.74, Accuracy=64.4%


Epoch 42: 100%|██████████| 21/21 [00:00<00:00, 126.65batch/s, acc=0.6, loss=0.73]


Epoch 42 validation: Cross-entropy=0.73, Accuracy=64.4%


Epoch 43: 100%|██████████| 21/21 [00:00<00:00, 28.35batch/s, acc=0.6, loss=0.724]


Epoch 43 validation: Cross-entropy=0.72, Accuracy=64.4%


Epoch 44: 100%|██████████| 21/21 [00:00<00:00, 79.02batch/s, acc=0.6, loss=0.719]


Epoch 44 validation: Cross-entropy=0.72, Accuracy=64.4%


Epoch 45: 100%|██████████| 21/21 [00:00<00:00, 113.83batch/s, acc=0.6, loss=0.714]


Epoch 45 validation: Cross-entropy=0.71, Accuracy=64.4%


Epoch 46: 100%|██████████| 21/21 [00:00<00:00, 117.40batch/s, acc=0.6, loss=0.709]


Epoch 46 validation: Cross-entropy=0.71, Accuracy=64.4%


Epoch 47: 100%|██████████| 21/21 [00:00<00:00, 145.58batch/s, acc=0.6, loss=0.705]


Epoch 47 validation: Cross-entropy=0.70, Accuracy=64.4%


Epoch 48: 100%|██████████| 21/21 [00:00<00:00, 117.26batch/s, acc=0.6, loss=0.701]


Epoch 48 validation: Cross-entropy=0.70, Accuracy=64.4%


Epoch 49: 100%|██████████| 21/21 [00:00<00:00, 100.20batch/s, acc=0.6, loss=0.697]


Epoch 49 validation: Cross-entropy=0.69, Accuracy=64.4%


Epoch 50: 100%|██████████| 21/21 [00:00<00:00, 102.19batch/s, acc=0.6, loss=0.693]


Epoch 50 validation: Cross-entropy=0.69, Accuracy=64.4%


Epoch 51: 100%|██████████| 21/21 [00:00<00:00, 62.38batch/s, acc=0.6, loss=0.688]


Epoch 51 validation: Cross-entropy=0.68, Accuracy=64.4%


Epoch 52: 100%|██████████| 21/21 [00:00<00:00, 87.29batch/s, acc=0.6, loss=0.685]


Epoch 52 validation: Cross-entropy=0.68, Accuracy=64.4%


Epoch 53: 100%|██████████| 21/21 [00:00<00:00, 131.08batch/s, acc=0.6, loss=0.681]


Epoch 53 validation: Cross-entropy=0.67, Accuracy=64.4%


Epoch 54: 100%|██████████| 21/21 [00:00<00:00, 90.67batch/s, acc=0.6, loss=0.677]


Epoch 54 validation: Cross-entropy=0.67, Accuracy=64.4%


Epoch 55: 100%|██████████| 21/21 [00:00<00:00, 82.88batch/s, acc=0.6, loss=0.674] 


Epoch 55 validation: Cross-entropy=0.67, Accuracy=64.4%


Epoch 56: 100%|██████████| 21/21 [00:00<00:00, 30.81batch/s, acc=0.6, loss=0.67]


Epoch 56 validation: Cross-entropy=0.66, Accuracy=64.4%


Epoch 57: 100%|██████████| 21/21 [00:00<00:00, 116.81batch/s, acc=0.6, loss=0.667]


Epoch 57 validation: Cross-entropy=0.66, Accuracy=64.4%


Epoch 58: 100%|██████████| 21/21 [00:00<00:00, 138.03batch/s, acc=0.6, loss=0.664]


Epoch 58 validation: Cross-entropy=0.66, Accuracy=64.4%


Epoch 59: 100%|██████████| 21/21 [00:00<00:00, 143.96batch/s, acc=0.6, loss=0.661]


Epoch 59 validation: Cross-entropy=0.65, Accuracy=64.4%


Epoch 60: 100%|██████████| 21/21 [00:00<00:00, 149.13batch/s, acc=0.6, loss=0.658]


Epoch 60 validation: Cross-entropy=0.65, Accuracy=64.4%


Epoch 61: 100%|██████████| 21/21 [00:00<00:00, 132.23batch/s, acc=0.6, loss=0.655]


Epoch 61 validation: Cross-entropy=0.64, Accuracy=64.4%


Epoch 62: 100%|██████████| 21/21 [00:00<00:00, 35.67batch/s, acc=0.6, loss=0.652] 


Epoch 62 validation: Cross-entropy=0.64, Accuracy=64.4%


Epoch 63: 100%|██████████| 21/21 [00:00<00:00, 121.54batch/s, acc=0.6, loss=0.649]


Epoch 63 validation: Cross-entropy=0.64, Accuracy=64.4%


Epoch 64: 100%|██████████| 21/21 [00:00<00:00, 125.77batch/s, acc=0.6, loss=0.646]


Epoch 64 validation: Cross-entropy=0.63, Accuracy=64.4%


Epoch 65: 100%|██████████| 21/21 [00:00<00:00, 86.96batch/s, acc=0.6, loss=0.644]


Epoch 65 validation: Cross-entropy=0.63, Accuracy=64.4%


Epoch 66: 100%|██████████| 21/21 [00:00<00:00, 106.31batch/s, acc=0.6, loss=0.641]


Epoch 66 validation: Cross-entropy=0.63, Accuracy=64.4%


Epoch 67: 100%|██████████| 21/21 [00:00<00:00, 28.53batch/s, acc=0.6, loss=0.638]


Epoch 67 validation: Cross-entropy=0.62, Accuracy=64.4%


Epoch 68: 100%|██████████| 21/21 [00:00<00:00, 81.74batch/s, acc=0.6, loss=0.636] 


Epoch 68 validation: Cross-entropy=0.62, Accuracy=64.4%


Epoch 69: 100%|██████████| 21/21 [00:00<00:00, 109.74batch/s, acc=0.6, loss=0.633]


Epoch 69 validation: Cross-entropy=0.62, Accuracy=64.4%


Epoch 70: 100%|██████████| 21/21 [00:00<00:00, 76.17batch/s, acc=0.6, loss=0.631]


Epoch 70 validation: Cross-entropy=0.62, Accuracy=64.4%


Epoch 71: 100%|██████████| 21/21 [00:00<00:00, 100.18batch/s, acc=0.6, loss=0.628]


Epoch 71 validation: Cross-entropy=0.61, Accuracy=64.4%


Epoch 72: 100%|██████████| 21/21 [00:00<00:00, 51.82batch/s, acc=0.6, loss=0.626]


Epoch 72 validation: Cross-entropy=0.61, Accuracy=64.4%


Epoch 73: 100%|██████████| 21/21 [00:00<00:00, 52.17batch/s, acc=0.6, loss=0.624]


Epoch 73 validation: Cross-entropy=0.61, Accuracy=64.4%


Epoch 74: 100%|██████████| 21/21 [00:00<00:00, 88.81batch/s, acc=0.6, loss=0.621]


Epoch 74 validation: Cross-entropy=0.60, Accuracy=64.4%


Epoch 75: 100%|██████████| 21/21 [00:00<00:00, 74.95batch/s, acc=0.6, loss=0.619]


Epoch 75 validation: Cross-entropy=0.60, Accuracy=64.4%


Epoch 76: 100%|██████████| 21/21 [00:00<00:00, 25.65batch/s, acc=0.6, loss=0.617]


Epoch 76 validation: Cross-entropy=0.60, Accuracy=64.4%


Epoch 77: 100%|██████████| 21/21 [00:00<00:00, 55.09batch/s, acc=0.6, loss=0.615]


Epoch 77 validation: Cross-entropy=0.60, Accuracy=64.4%


Epoch 78: 100%|██████████| 21/21 [00:00<00:00, 65.10batch/s, acc=0.6, loss=0.613]


Epoch 78 validation: Cross-entropy=0.59, Accuracy=64.4%


Epoch 79: 100%|██████████| 21/21 [00:00<00:00, 76.82batch/s, acc=0.6, loss=0.611] 


Epoch 79 validation: Cross-entropy=0.59, Accuracy=66.7%


Epoch 80: 100%|██████████| 21/21 [00:00<00:00, 94.20batch/s, acc=0.6, loss=0.609] 


Epoch 80 validation: Cross-entropy=0.59, Accuracy=66.7%


Epoch 81: 100%|██████████| 21/21 [00:00<00:00, 34.38batch/s, acc=0.6, loss=0.606]


Epoch 81 validation: Cross-entropy=0.59, Accuracy=66.7%


Epoch 82: 100%|██████████| 21/21 [00:00<00:00, 63.98batch/s, acc=0.6, loss=0.604]


Epoch 82 validation: Cross-entropy=0.58, Accuracy=66.7%


Epoch 83: 100%|██████████| 21/21 [00:00<00:00, 62.16batch/s, acc=0.6, loss=0.602]


Epoch 83 validation: Cross-entropy=0.58, Accuracy=66.7%


Epoch 84: 100%|██████████| 21/21 [00:00<00:00, 70.01batch/s, acc=0.6, loss=0.6]


Epoch 84 validation: Cross-entropy=0.58, Accuracy=66.7%


Epoch 85: 100%|██████████| 21/21 [00:00<00:00, 134.37batch/s, acc=0.6, loss=0.599]


Epoch 85 validation: Cross-entropy=0.58, Accuracy=66.7%


Epoch 86: 100%|██████████| 21/21 [00:00<00:00, 103.89batch/s, acc=0.6, loss=0.597]


Epoch 86 validation: Cross-entropy=0.57, Accuracy=66.7%


Epoch 87: 100%|██████████| 21/21 [00:00<00:00, 168.30batch/s, acc=0.6, loss=0.595]


Epoch 87 validation: Cross-entropy=0.57, Accuracy=66.7%


Epoch 88: 100%|██████████| 21/21 [00:00<00:00, 68.12batch/s, acc=0.6, loss=0.593]


Epoch 88 validation: Cross-entropy=0.57, Accuracy=66.7%


Epoch 89: 100%|██████████| 21/21 [00:00<00:00, 92.12batch/s, acc=0.6, loss=0.591]


Epoch 89 validation: Cross-entropy=0.57, Accuracy=66.7%


Epoch 90: 100%|██████████| 21/21 [00:00<00:00, 97.77batch/s, acc=0.6, loss=0.589]


Epoch 90 validation: Cross-entropy=0.57, Accuracy=66.7%


Epoch 91: 100%|██████████| 21/21 [00:00<00:00, 139.00batch/s, acc=0.6, loss=0.588]


Epoch 91 validation: Cross-entropy=0.56, Accuracy=66.7%


Epoch 92: 100%|██████████| 21/21 [00:00<00:00, 57.37batch/s, acc=0.6, loss=0.586]


Epoch 92 validation: Cross-entropy=0.56, Accuracy=66.7%


Epoch 93: 100%|██████████| 21/21 [00:00<00:00, 26.92batch/s, acc=0.6, loss=0.584]


Epoch 93 validation: Cross-entropy=0.56, Accuracy=68.9%


Epoch 94: 100%|██████████| 21/21 [00:00<00:00, 108.61batch/s, acc=0.6, loss=0.583]


Epoch 94 validation: Cross-entropy=0.56, Accuracy=68.9%


Epoch 95: 100%|██████████| 21/21 [00:00<00:00, 91.63batch/s, acc=0.6, loss=0.581]


Epoch 95 validation: Cross-entropy=0.56, Accuracy=68.9%


Epoch 96: 100%|██████████| 21/21 [00:00<00:00, 21.57batch/s, acc=0.6, loss=0.579]


Epoch 96 validation: Cross-entropy=0.55, Accuracy=68.9%


Epoch 97: 100%|██████████| 21/21 [00:00<00:00, 34.32batch/s, acc=0.6, loss=0.578]


Epoch 97 validation: Cross-entropy=0.55, Accuracy=68.9%


Epoch 98: 100%|██████████| 21/21 [00:00<00:00, 82.70batch/s, acc=0.6, loss=0.576]


Epoch 98 validation: Cross-entropy=0.55, Accuracy=68.9%


Epoch 99: 100%|██████████| 21/21 [00:00<00:00, 64.22batch/s, acc=0.6, loss=0.575]

Epoch 99 validation: Cross-entropy=0.55, Accuracy=68.9%



