In [135]:
from numpy.testing import assert_equal
from torch import flatten, load, max, save
from torch.cuda import is_available
from torch.nn import CrossEntropyLoss, Linear, Module, Sequential, Sigmoid
from torch.optim.sgd import SGD
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm.notebook import tqdm

## MNIST MLP

This section serves to get used to creating and training own models. Later this will be transformed into the nets, we use to apply our uncertainty propagation.

### Training dataset download and pre-processing

In [136]:
_train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(0.13, 0.3081)]
)

In [137]:
train_data = MNIST(root="./train/", download=True, transform=_train_transforms)
train_data

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./train/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=0.13, std=0.3081)
           )

In [138]:
_n_test_samples = int(0.8 * len(train_data))
_n_validat_samples = int(len(train_data) - _n_test_samples)
assert _n_test_samples + _n_validat_samples == len(train_data)
train_set, validat_set = random_split(train_data, [_n_test_samples, _n_validat_samples])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
validat_loader = DataLoader(validat_set, batch_size=16, shuffle=True)

### Network design

The chosen network design resembles the proposed structure from [stats.stackexchange
](https://stats.stackexchange.com/questions/376312/mnist-digit-recognition-what-is-the-best-we-can-get-with-a-fully-connected-nn-o).

In [139]:
class MNIST_MLP(Module):
    def __init__(self):
        super(MNIST_MLP, self).__init__()
        self._net = Sequential(
            Linear(784, 200), Sigmoid(), Linear(200, 80), Sigmoid(), Linear(80, 10)
        )

    def forward(self, x_0):
        passed_input = flatten(x_0, 1)
        return self._net(passed_input)

In [140]:
device = "cuda" if is_available() else "cpu"
model = MNIST_MLP().to(device=device)
assert_equal(tuple(model.children())[-1][-1].out_features, len(train_data.classes))

### Training
#### Loss function and optimization method

In [141]:
loss_func = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=1e-2)

#### Actual training

In [148]:
N_EPOCHS = 3
for epoch in tqdm(range(N_EPOCHS), desc=f"Epoch {epoch}"):
    train_loss = 0.0
    model.train()
    for train_x_0s, train_labels in tqdm(
        train_loader, desc=f"Training during epoch {epoch}"
    ):
        train_x_0s = train_x_0s.to(device)
        train_labels = train_labels.to(device)
        optimizer.zero_grad()
        outputs = model(train_x_0s)
        tmp_train_loss = loss_func(outputs, train_labels)
        tmp_train_loss.backward()
        optimizer.step()
        train_loss += tmp_train_loss.item()

    validat_loss = 0.0
    model.eval()
    for validat_x_0s, validat_labels in tqdm(
        validat_loader, desc=f"Validation during epoch {epoch}"
    ):
        validat_x_0s = validat_x_0s.to(device)
        validat_labels = validat_labels.to(device)
        outputs = model(validat_x_0s)
        tmp_validat_loss = loss_func(outputs, validat_labels)
        validat_loss += tmp_validat_loss.item()

    print(
        f"Epoch: {epoch}, Train Loss: {train_loss/len(train_loader)}, Validation Loss: {validat_loss/len(validat_loader)}"
    )

Epoch 2:   0%|          | 0/3 [00:00<?, ?it/s]

Training during epoch 0:   0%|          | 0/3000 [00:00<?, ?it/s]

Validation during epoch 0:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch: 0, Train Loss: 0.40202696987241504, Validation Loss: 0.3769043511946996


Training during epoch 1:   0%|          | 0/3000 [00:00<?, ?it/s]

Validation during epoch 1:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch: 1, Train Loss: 0.346245320511361, Validation Loss: 0.3367252284909288


Training during epoch 2:   0%|          | 0/3000 [00:00<?, ?it/s]

Validation during epoch 2:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch: 2, Train Loss: 0.31111704791523515, Validation Loss: 0.3112695779800415


In [150]:
model_path = "mnist_mlp.pt"
save(model.state_dict(), model_path)

### Testing

In [144]:
model.load_state_dict(load(model_path))

<All keys matched successfully>

In [145]:
test_data = MNIST(root="./test/", train=False, download=True, transform=ToTensor())
test_data

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./test/
    Split: Test
    StandardTransform
Transform: ToTensor()

In [146]:
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

In [149]:
num_correct = 0.0
for x_test_batch, y_test_batch in tqdm(test_loader):
    model.eval()
    y_test_batch = y_test_batch.to(device)
    x_test_batch = x_test_batch.to(device)
    y_pred_batch = model(x_test_batch)
    _, predicted = max(y_pred_batch, 1)
    num_correct += (predicted == y_test_batch).float().sum()
accuracy = num_correct / (len(test_loader) * test_loader.batch_size)
print(f"Test accuracy: {accuracy}")

  0%|          | 0/625 [00:00<?, ?it/s]

Test accuracy: 0.8238999843597412
