In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

In [27]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [28]:
class IrisDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __getitem__(self, i):
        return torch.tensor(self.X[i], dtype=torch.float), torch.tensor(self.y[i], dtype=torch.int64)

    def __len__(self):
        return len(self.y)

In [41]:
X, y = load_iris(return_X_y=True)
train_idxs, rest_idxs = train_test_split(range(len(y)), train_size=0.8)
dev_idxs, test_idxs = train_test_split(rest_idxs, train_size=0.5)

trainset = IrisDataset(X[train_idxs], y[train_idxs])
devset = IrisDataset(X[dev_idxs], y[dev_idxs])
testset = IrisDataset(X[test_idxs], y[test_idxs])

In [61]:
train_dataloader = DataLoader(trainset, batch_size=8)
dev_dataloader = DataLoader(devset, batch_size=8)
test_dataloader = DataLoader(testset, batch_size=8)

In [135]:
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3)
        )

    def forward(self, X):
        logits = self.linear_stack(X)
        return logits

    def fit(
        self,
        train_dataloader,
        dev_dataloader=None,
        num_epochs=50,
        loss_fn=nn.CrossEntropyLoss,
        optimizer=torch.optim.SGD,
        lr=1e-2
    ):
        size = len(train_dataloader.dataset)
        criterion = loss_fn()
        optimizer = optimizer(self.parameters(), lr=lr)

        self.to(device)
        self.train()
        num_batch = 0
        for epoch in range(num_epochs):
            for X_batch, y_batch in train_dataloader:
                batch_size = len(y_batch)

                X_batch, y_true = X_batch.to(device), y_batch.to(device)
    
                logits = self(X_batch)
                y_pred = nn.Softmax(dim=1)(logits)
                loss = criterion(y_pred, y_true)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss, global_step, total_steps = loss.item(), num_batch * batch_size, size*num_epochs
                num_batch += 1
            
            print(f"epoch: {epoch}  loss: {loss:>7f}  batch: {global_step:5d}/{total_steps:>5d}")
            if dev_dataloader is not None:
                self.score(dev_dataloader, loss_fn)
                self.train()
            print("\n")

    def score(self, test_dataloader, loss_fn=nn.CrossEntropyLoss):
        criterion = loss_fn()
        self.eval()

        num_batches = len(test_dataloader)
        num_examples = len(test_dataloader.dataset)
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X_batch, y_batch in test_dataloader:
                X_batch, y_true = X_batch.to(device), y_batch.to(device)

                logits = self(X_batch)
                loss = criterion(logits, y_true)

                test_loss += loss.item()
                correct += (logits.argmax(dim=1) == y_true).sum().item()
            
        test_loss /= num_batches
        correct /= num_examples
        print(f"Test Error Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")

In [155]:
clf = MyNet().to(device)
clf.fit(train_dataloader, dev_dataloader, num_epochs=100)

epoch: 0  loss: 1.063545  batch:   112/12000
Test Error Accuracy: 66.7%, Avg loss: 0.958480


epoch: 1  loss: 1.034826  batch:   232/12000
Test Error Accuracy: 66.7%, Avg loss: 0.931232


epoch: 2  loss: 0.997870  batch:   352/12000
Test Error Accuracy: 66.7%, Avg loss: 0.913680


epoch: 3  loss: 0.959804  batch:   472/12000
Test Error Accuracy: 66.7%, Avg loss: 0.894967


epoch: 4  loss: 0.924299  batch:   592/12000
Test Error Accuracy: 66.7%, Avg loss: 0.875628


epoch: 5  loss: 0.893409  batch:   712/12000
Test Error Accuracy: 66.7%, Avg loss: 0.853440


epoch: 6  loss: 0.866690  batch:   832/12000
Test Error Accuracy: 66.7%, Avg loss: 0.825373


epoch: 7  loss: 0.844457  batch:   952/12000
Test Error Accuracy: 66.7%, Avg loss: 0.793421


epoch: 8  loss: 0.826744  batch:  1072/12000
Test Error Accuracy: 66.7%, Avg loss: 0.756622


epoch: 9  loss: 0.811613  batch:  1192/12000
Test Error Accuracy: 66.7%, Avg loss: 0.716865


epoch: 10  loss: 0.798147  batch:  1312/12000
Test Error Acc

In [158]:
clf.score(test_dataloader)

Test Error Accuracy: 86.7%, Avg loss: 0.291586


In [159]:
torch.save(clf.state_dict(), "model.pt")