In [51]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch import optim
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import tqdm
from copy import deepcopy

In [52]:
data = load_iris()
X = data['data']
y = [[t] for t in data['target']]
ohe = OneHotEncoder(handle_unknown='ignore', sparse_output=False).fit(y)
y = ohe.transform(y)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

In [53]:
xtr, xts, ytr, yts = train_test_split(X, y, train_size=0.7, shuffle=True)

In [54]:
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.output(x)
        return x

In [55]:
def get_batched(data, batch_size):
    if batch_size is None:
        return data
    dc = deepcopy(data)
    while dc.numel():
        batch, dc = dc[:batch_size], dc[batch_size:]
        yield batch

def format_stats(loss, acc):
    return f"Loss={loss:.2f}, Accuracy={acc*100:.1f}%"
        
def train(model, xtr, ytr, batch_size = None):
    cum_loss, cum_acc = [], []
    for xtrb, ytrb in zip(get_batched(xtr, batch_size), get_batched(ytr, batch_size)):
        ypred = model(xtrb)
        loss = loss_fn(ypred, ytrb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (torch.argmax(ypred, 1) == torch.argmax(ytrb, 1)).float().mean()
        cum_loss.append(float(loss))
        cum_acc.append(float(acc))
    return cum_loss, cum_acc

def test(model, xts, yts):
    ypred = model(xts)
    loss = loss_fn(ypred, yts)
    acc = (torch.argmax(ypred, 1) == torch.argmax(xts, 1)).float().mean()
    return float(loss), float(acc)


def train_epochs(model, xtr, ytr, xts, yts, n_epochs = 100, batch_size = 5):
    best_acc = - np.inf
    best_weights = None
    train_loss_hist = []
    train_acc_hist = []
    
    for epoch in range(n_epochs):
        epoch_loss = []
        epoch_acc = []
        model.train()
        cum_loss, cum_acc = train(model, xtr, ytr, batch_size)
        model.eval()
        mloss, macc = np.mean(cum_loss), np.mean(cum_acc)
        train_loss_hist.append(mloss)
        train_acc_hist.append(macc)
        if macc > best_acc:
            best_acc = macc
            best_weights = deepcopy(model.state_dict())
        print(f"Epoch {epoch} validation: Cross-entropy={mloss:.2f}, Accuracy={macc*100:.1f}%")
    model.load_state_dict(best_weights)




In [56]:
model = Multiclass()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

pre_loss, pre_acc = test(model, xts, yts)
train_epochs(model, xtr, ytr, xts, yts)
post_loss, post_acc = test(model, xts, yts)
format_stats(pre_loss, pre_acc), format_stats(post_loss, post_acc)

Epoch 0 validation: Cross-entropy=1.28, Accuracy=32.4%
Epoch 1 validation: Cross-entropy=1.15, Accuracy=32.4%
Epoch 2 validation: Cross-entropy=1.07, Accuracy=37.1%
Epoch 3 validation: Cross-entropy=1.01, Accuracy=58.1%
Epoch 4 validation: Cross-entropy=0.97, Accuracy=65.7%
Epoch 5 validation: Cross-entropy=0.94, Accuracy=65.7%
Epoch 6 validation: Cross-entropy=0.92, Accuracy=65.7%
Epoch 7 validation: Cross-entropy=0.89, Accuracy=65.7%
Epoch 8 validation: Cross-entropy=0.87, Accuracy=65.7%
Epoch 9 validation: Cross-entropy=0.85, Accuracy=65.7%
Epoch 10 validation: Cross-entropy=0.83, Accuracy=65.7%
Epoch 11 validation: Cross-entropy=0.81, Accuracy=65.7%
Epoch 12 validation: Cross-entropy=0.79, Accuracy=65.7%
Epoch 13 validation: Cross-entropy=0.77, Accuracy=65.7%
Epoch 14 validation: Cross-entropy=0.75, Accuracy=65.7%
Epoch 15 validation: Cross-entropy=0.73, Accuracy=65.7%
Epoch 16 validation: Cross-entropy=0.71, Accuracy=65.7%
Epoch 17 validation: Cross-entropy=0.70, Accuracy=65.7%
Ep

('Loss=1.30, Accuracy=0.0%', 'Loss=0.34, Accuracy=33.3%')

0