In [None]:
!pip install lrcurve

In [2]:
import sklearn.datasets
import sklearn.model_selection
import sklearn.metrics
import torch

from lrcurve import PlotLearningCurve

In [6]:
# define dataset
# Split iris into train/test with a fixed seed so runs are reproducible.
x_train, x_test, y_train, y_test = sklearn.model_selection.train_test_split(
    *sklearn.datasets.load_iris(return_X_y=True),
    random_state=0
)
# Features as float32 to match the default dtype of torch.nn.Linear weights.
x_train = torch.from_numpy(x_train.astype('float32'))
x_test = torch.from_numpy(x_test.astype('float32'))
# CrossEntropyLoss requires int64 (Long) class indices. The numpy integer dtype
# of the iris targets follows the platform default int (possibly int32), so
# cast explicitly instead of relying on it.
y_train = torch.from_numpy(y_train.astype('int64'))
y_test = torch.from_numpy(y_test.astype('int64'))

# define model
class Network(torch.nn.Module):
    """Small tanh MLP classifier: 4 input features -> 32 -> 16 -> 3 classes.

    ``forward`` returns unnormalized logits, which is what
    ``torch.nn.CrossEntropyLoss`` expects (it applies log-softmax itself).
    Use ``predict_proba`` when class probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 32)
        self.fc2 = torch.nn.Tanh()
        self.fc3 = torch.nn.Linear(32, 16)
        self.fc4 = torch.nn.Tanh()
        self.fc5 = torch.nn.Linear(16, 3)
        # Kept as a submodule (it has no parameters, so state_dict is unchanged),
        # but applied only in predict_proba: running it before CrossEntropyLoss
        # would normalize twice, flattening gradients and bounding the loss
        # away from zero.
        self.fc6 = torch.nn.Softmax(dim=-1)

    def forward(self, z):
        """Return class logits for a batch of feature rows ``z``."""
        z = self.fc1(z)
        z = self.fc2(z)
        z = self.fc3(z)
        z = self.fc4(z)
        z = self.fc5(z)
        return z

    def predict_proba(self, z):
        """Return softmax class probabilities (each row sums to 1)."""
        return self.fc6(self.forward(z))

# setup: model, loss function, optimizer and the live learning-curve plot
network = Network()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters())

# Each key later passed to plot.append is assigned a line name and a facet
# (panel); train/validation curves of the same metric share a facet.
curve_mappings = {
    'loss': {'line': 'train', 'facet': 'loss'},
    'val_loss': {'line': 'validation', 'facet': 'loss'},
    'acc': {'line': 'train', 'facet': 'acc'},
    'val_acc': {'line': 'validation', 'facet': 'acc'},
}
# Display name and y-axis limits per facet (None leaves that end unbounded).
curve_facets = {
    'loss': {'name': 'Cross-Entropy', 'limit': [0, None]},
    'acc': {'name': 'Accuracy', 'limit': [0, 1]},
}
curve_xaxis = {'name': 'Epoch', 'limit': [0, 500]}

plot = PlotLearningCurve(
    mappings=curve_mappings,
    facet_config=curve_facets,
    xaxis_config=curve_xaxis,
)

# optimize model
epochs = 500  # keep in sync with the plot's xaxis_config 'limit'

with plot:
    for epoch in range(epochs):
        # Evaluate on held-out data with the pre-step parameters. no_grad avoids
        # building (and keeping) an autograd graph that is never backpropagated.
        with torch.no_grad():
            z_test = network(x_test)
            loss_test = criterion(z_test, y_test)

        # One full-batch gradient step on the training data.
        optimizer.zero_grad()
        z_train = network(x_train)
        loss_train = criterion(z_train, y_train)
        loss_train.backward()
        optimizer.step()

        # Compute accuracy; sklearn's signature is accuracy_score(y_true, y_pred).
        pred_test = torch.argmax(z_test, 1).numpy()
        pred_train = torch.argmax(z_train, 1).detach().numpy()
        accuracy_test = sklearn.metrics.accuracy_score(y_test.numpy(), pred_test)
        accuracy_train = sklearn.metrics.accuracy_score(y_train.numpy(), pred_train)

        # Append plain Python floats (not tensors) and redraw the curves.
        plot.append(epoch, {
            'loss': loss_train.item(),
            'val_loss': loss_test.item(),
            'acc': accuracy_train,
            'val_acc': accuracy_test,
        })
        plot.draw()