In [1]:
import unittest
from pathlib import Path
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from tools.torch.trainers import TorchTrainer
from tools.torch.listeners import Listeners, TensorBoardLossReporter, TensorBoardModelReporter

In [2]:
# Base directory (relative to the working directory) under which this notebook
# keeps its dataset, checkpoint and log folders.
test_dir = Path('test')
# Folder the FashionMNIST files are downloaded into.
root = test_dir / 'data'

In [3]:
# Device handed to the trainer below; fixed to the CPU in this notebook.
device = torch.device('cpu')

In [4]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier that returns raw (unnormalised) logits.

    Each sample is flattened to a vector, passed through two hidden
    ``Linear`` + ``ReLU`` stages of equal width, and projected to one logit per
    class. No softmax is applied in this module.

    Args:
        in_features: Number of values per sample after flattening.
            Defaults to ``28 * 28``.
        hidden_features: Width of both hidden layers. Defaults to ``512``.
        num_classes: Number of logits produced per sample. Defaults to ``10``.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # The attribute names below determine the state_dict keys, so they are
        # kept unchanged to stay compatible with previously saved weights.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Flatten every sample in ``x`` and return its ``num_classes`` logits."""
        return self.linear_relu_stack(self.flatten(x))

In [5]:
# FashionMNIST is downloaded into `root` when it is not already there
# (download=True); every image goes through ToTensor().

# Training split (train=True).
training_data = datasets.FashionMNIST(root=root, train=True, transform=ToTensor(), download=True)

# Test split (train=False); the trainer below receives it as validation data.
test_data = datasets.FashionMNIST(root=root, train=False, transform=ToTensor(), download=True)

In [6]:
# Seed torch's global RNG before the model is built so that the initial weights
# (and any later torch-driven randomness) are the same on every
# Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
model = NeuralNetwork()

In [7]:
# Loss handed to the trainer. CrossEntropyLoss works on raw (unnormalised)
# logits, which is what NeuralNetwork.forward returns.
criterion = nn.CrossEntropyLoss()

In [8]:
# SGD over every parameter of `model` with a learning rate of 1e-3; all other
# optimizer settings are left at their defaults.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

In [9]:
# Output folders under test_dir: one is given to the trainer as its checkpoint
# directory, the other to the TensorBoard reporters.
checkpoint_dir = test_dir / 'checkpoint'
log_dir = test_dir / 'log'

In [10]:
# Build the loss reporter and the model reporter (in that order), each with the
# same log_dir, and wrap them in a single Listeners object for the trainer.
listener = Listeners([
    reporter_cls(log_dir)
    for reporter_cls in (TensorBoardLossReporter, TensorBoardModelReporter)
])

In [11]:
# Both values are passed to the trainer and are also embedded in the run name.
# batch_multi is presumably a gradient-accumulation factor — TODO confirm
# against TorchTrainer.
batch_size, batch_multi = 16, 1

# Keyword arguments grouped by role: run identity, data, learning objects, outputs.
trainer = TorchTrainer(
    name=f'mnist_{batch_size}_{batch_multi}',
    device=device,
    epochs=5,
    train_data=training_data,
    val_data=test_data,
    batch_size=batch_size,
    batch_multi=batch_multi,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    listener=listener,
    checkpoint_dir=checkpoint_dir,
)

In [12]:
# Start the training run configured above (long-running; prints a progress bar).
# NOTE(review): the saved output below reports "epoch: 6" although the trainer
# was created with epochs=5, and it ends in KeyboardInterrupt — presumably an
# earlier run or checkpoint was resumed. Confirm, then re-run from a fresh
# kernel so the stored output matches the code.
trainer.start()

epoch: 6, step: 492, loss: 0.844:  13%|████████████████████▉                                                                                                                                           | 492/3750 [00:10<01:07, 48.47it/s]


KeyboardInterrupt: 