In [None]:
import torch
import torchvision.datasets as datasets  # for Mist
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # To inherit our neural network
from torch.utils.data import DataLoader  # For management of the dataset (batches)
from tqdm import tqdm  # For nice progress bar!
import matplotlib.pyplot as plt
import pandas as pd


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Best test accuracy reached so far. NN._train only writes a checkpoint when an
# epoch's test accuracy beats this value, and then raises it to that accuracy.
best_acc = 0.9954


class NN(nn.Module):
    """Fully connected classifier: three 1024-unit hidden layers, each followed
    by BatchNorm1d and ReLU, then a linear layer producing ``num_classes`` logits.

    The helper methods (``_train``, ``_test``, ``_predict``) flatten every
    sample to one dimension before the forward pass, so ``input_size`` must be
    the number of elements in a single sample (784 for 28x28 images).
    """

    def __init__(self, input_size, num_classes):
        """
        Args:
            input_size: Number of features in one flattened sample.
            num_classes: Number of output logits.
        """
        super().__init__()

        def init_weights(m):
            # Xavier-uniform weights and a small constant bias for every Linear layer.
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.01)

        # NOTE: keep this exact layer order -- state_dict keys (net.0, net.1, ...)
        # must match previously saved checkpoints such as model01.pt.
        self.net = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes),
        )

        self.net.apply(init_weights)

    def forward(self, x):
        """Return raw logits for ``x``, which must already be flattened to
        ``(batch, input_size)``."""
        return self.net(x)

    def _param_device(self):
        """Return the device this model's parameters live on.

        Inputs are moved here so they always match the model, even if the model
        was never moved to the module-level ``device``.
        """
        return next(self.parameters()).device

    @staticmethod
    def _flatten(x):
        """Collapse every non-batch dimension, e.g. (B, 1, 28, 28) -> (B, 784)."""
        return x.reshape(x.shape[0], -1)

    def _train(
        self,
        *,
        train_loader,
        test_loader,
        num_epochs,
        criterion,
        optimizer,
        save_path="model.pt",
    ):
        """Train for ``num_epochs`` epochs, evaluating on ``test_loader`` after each.

        Whenever an epoch's test accuracy beats the module-level ``best_acc``,
        that global is updated, the weights are saved to ``save_path`` and
        "BEST" is printed. When training finishes, the per-epoch test accuracy
        and mean training loss are plotted.

        Args:
            train_loader: Sized iterable of ``(data, targets)`` batches.
            test_loader: Iterable of ``(data, targets)`` batches for evaluation.
            num_epochs: Number of passes over ``train_loader``.
            criterion: Loss callable taking ``(logits, targets)``.
            optimizer: Optimizer over this model's parameters.
            save_path: File the best checkpoint is written to.
        """
        global best_acc

        dev = self._param_device()
        losses = []
        test_acc = []

        # BatchNorm must use batch statistics while training, even if the caller
        # left the model in eval mode. _test() restores this mode after each call.
        self.train()

        for _ in range(num_epochs):
            loss_sum = 0.0

            for data, targets in tqdm(train_loader):
                data = self._flatten(data.to(device=dev))
                targets = targets.to(device=dev)

                # Forward
                scores = self(data)
                loss = criterion(scores, targets)
                loss_sum += loss.item()

                # Backward + optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Mean loss over the epoch's batches.
            losses.append(loss_sum / len(train_loader))

            # Test accuracy for the current epoch
            curr_test_acc = self._test(test_loader)
            test_acc.append(curr_test_acc)

            if curr_test_acc > best_acc:
                best_acc = curr_test_acc
                torch.save(self.state_dict(), save_path)
                print("BEST", end=" ")

            print(f"Test acc: {curr_test_acc}")

        # Plot test accuracy and loss on fresh figures so they never overlay
        # whatever figure happened to be current.
        plt.figure()
        plt.title("Test accuracy")
        plt.plot(test_acc, ".-g")
        plt.show()

        plt.figure()
        plt.title("Loss")
        plt.plot(losses, ".-r")
        plt.show()

    def _test(self, loader):
        """Return the fraction of samples in ``loader`` that are classified correctly.

        Runs in eval mode without gradients. The module's previous train/eval
        mode is restored afterwards, even if an exception is raised.

        Args:
            loader: Iterable of ``(x, y)`` batches with integer class labels ``y``.
        """
        was_training = self.training
        dev = self._param_device()
        num_correct = 0
        num_samples = 0

        self.eval()
        try:
            with torch.no_grad():
                for x, y in loader:
                    x = self._flatten(x.to(device=dev))
                    y = y.to(device=dev)

                    predictions = self(x).argmax(dim=1)
                    num_correct += (predictions == y).sum().item()
                    num_samples += predictions.size(0)
        finally:
            self.train(was_training)

        return num_correct / num_samples

    def _predict(self, loader):
        """Predict a class index for every sample in ``loader``.

        Batches may be ``(x, y)`` pairs (the label is ignored) or bare ``x``
        tensors. IDs are consecutive integers starting at 0 in loader order.
        They are counted from the real batch sizes, so a short final batch
        still yields exactly one ID per prediction. The module's previous
        train/eval mode is restored afterwards.

        Returns:
            dict with equally long lists under ``"ID"`` and ``"target"``.
        """
        data = {
            "ID": [],
            "target": [],
        }

        was_training = self.training
        dev = self._param_device()
        offset = 0

        self.eval()
        try:
            with torch.no_grad():
                for batch in loader:
                    x = batch[0] if isinstance(batch, (list, tuple)) else batch
                    x = self._flatten(x.to(device=dev))

                    predictions = self(x).argmax(dim=1).tolist()

                    data["ID"].extend(range(offset, offset + len(predictions)))
                    data["target"].extend(predictions)
                    offset += len(predictions)
        finally:
            self.train(was_training)

        return data


# Hyperparameters
input_size = 28 * 28  # one MNIST image (28x28 pixels), flattened
num_classes = 10  # digit classes 0-9
learning_rate = 1e-3
batch_size = 100
num_epochs = 500

In [None]:
# Load Data
root = "data"

# Train-time augmentation. The two random rotations compound: RandomRotation
# rotates by up to +/-10 degrees and RandomAffine then rotates by up to a further
# +/-20 degrees (so up to +/-30 in total), translates by up to 10% of the width /
# height and scales by a factor in [0.9, 1.1].
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
    ]
)

# Test data (and the clean copy of the training data) is only converted to a tensor.
test_transform = transforms.ToTensor()

# Training set = one un-augmented copy of MNIST-train plus three augmented copies.
# torchvision applies the random transforms each time an item is loaded, so the
# augmented copies differ from epoch to epoch.
train_dataset = torch.utils.data.ConcatDataset(
    [datasets.MNIST(root=root, train=True, transform=test_transform, download=True)]
    + [
        # The constructor above has already fetched the files; no download needed.
        datasets.MNIST(root=root, train=True, transform=train_transform)
        for _ in range(3)
    ]
)
test_dataset = datasets.MNIST(
    root=root, train=False, transform=test_transform, download=True
)

# Shuffle training batches every epoch; keep test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Initialize network and warm-start it from a previously saved state_dict.
model = NN(input_size=input_size, num_classes=num_classes).to(device)
# map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
# weights_only=True limits unpickling to tensors and plain containers (enough for
# a state_dict), which is safer than loading an arbitrary pickle.
model.load_state_dict(torch.load("model01.pt", map_location=device, weights_only=True))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Adam with decoupled weight decay; use the learning_rate hyperparameter rather
# than repeating the literal, so changing it in one place takes effect.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001)
# Alternative optimizer / LR schedule (currently disabled):
# optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9, nesterov=True)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[16, 32], gamma=0.1)

# Train
model._train(
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    criterion=criterion,
    optimizer=optimizer,
)

# Note: train accuracy is measured on the augmented training set.
print(f"Train acc: {model._test(train_loader)}")
print(f"Test acc:  {model._test(test_loader)}")

  model.load_state_dict(torch.load("model01.pt"))
100%|██████████| 2400/2400 [00:54<00:00, 43.98it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.95it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:53<00:00, 45.15it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 45.15it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:52<00:00, 45.58it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.42it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:53<00:00, 45.05it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:53<00:00, 44.85it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:53<00:00, 44.89it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 44.54it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.75it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:54<00:00, 44.17it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.88it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 44.57it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:54<00:00, 44.26it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:53<00:00, 44.57it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.95it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:53<00:00, 44.82it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 44.62it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.96it/s]


Test acc: 0.9926999807357788


100%|██████████| 2400/2400 [00:54<00:00, 44.40it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:52<00:00, 45.63it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:53<00:00, 44.85it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:53<00:00, 45.19it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 44.85it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:55<00:00, 43.33it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:53<00:00, 44.46it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:54<00:00, 44.34it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:53<00:00, 45.16it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:54<00:00, 44.17it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.29it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 45.70it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.76it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 46.02it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.00it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 44.50it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:54<00:00, 44.17it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.89it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.54it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:54<00:00, 44.10it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 44.75it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.52it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 44.66it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:52<00:00, 45.84it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.33it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 45.00it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 44.80it/s]


Test acc: 0.9929999709129333


100%|██████████| 2400/2400 [00:53<00:00, 44.87it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:53<00:00, 45.23it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:53<00:00, 45.26it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.56it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.44it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 45.14it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.63it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:54<00:00, 44.34it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 45.01it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 45.07it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:53<00:00, 44.64it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:52<00:00, 45.40it/s]


BEST Test acc: 0.9950999617576599


100%|██████████| 2400/2400 [00:53<00:00, 44.91it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:54<00:00, 43.96it/s]


Test acc: 0.9922999739646912


100%|██████████| 2400/2400 [00:55<00:00, 43.55it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:54<00:00, 44.24it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.58it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:54<00:00, 44.33it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.99it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.56it/s]


Test acc: 0.9949999451637268


100%|██████████| 2400/2400 [00:53<00:00, 44.93it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:53<00:00, 45.03it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:52<00:00, 45.30it/s]


Test acc: 0.9929999709129333


100%|██████████| 2400/2400 [00:52<00:00, 45.84it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:53<00:00, 44.93it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 45.27it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:53<00:00, 44.44it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.72it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:54<00:00, 44.07it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:53<00:00, 44.60it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 44.89it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.78it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.88it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:51<00:00, 46.30it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 44.75it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:53<00:00, 44.71it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:54<00:00, 44.03it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:54<00:00, 43.95it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:53<00:00, 44.46it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:54<00:00, 44.36it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:54<00:00, 44.33it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:53<00:00, 44.53it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:54<00:00, 44.30it/s]


Test acc: 0.9950999617576599


100%|██████████| 2400/2400 [00:54<00:00, 44.27it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:54<00:00, 44.40it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:53<00:00, 44.45it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:53<00:00, 44.94it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:53<00:00, 44.93it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:54<00:00, 43.93it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:54<00:00, 44.03it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 45.16it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.67it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.49it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.28it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.44it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.67it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:53<00:00, 45.03it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.97it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 45.04it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:53<00:00, 44.65it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.52it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.49it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.63it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 44.97it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 45.20it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:53<00:00, 45.16it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.50it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.45it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.37it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.35it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:53<00:00, 44.84it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.86it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:53<00:00, 44.69it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 45.23it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.81it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:53<00:00, 45.18it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 45.26it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:53<00:00, 45.26it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.42it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:53<00:00, 45.08it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 44.79it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:56<00:00, 42.49it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:54<00:00, 44.12it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.45it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.42it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:53<00:00, 45.10it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:53<00:00, 45.26it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.55it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:53<00:00, 45.08it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.49it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.72it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:54<00:00, 44.29it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 45.10it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:54<00:00, 44.42it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:53<00:00, 45.20it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:54<00:00, 44.28it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:54<00:00, 44.40it/s]


Test acc: 0.9929999709129333


100%|██████████| 2400/2400 [00:51<00:00, 46.36it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:51<00:00, 46.64it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.40it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:52<00:00, 45.71it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.61it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:51<00:00, 46.75it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.94it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:51<00:00, 46.48it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.72it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.32it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 46.12it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.73it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.22it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.43it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.01it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:53<00:00, 44.94it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:51<00:00, 46.18it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.84it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 45.35it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.36it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.35it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.35it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.39it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.93it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:51<00:00, 46.82it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.49it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:51<00:00, 46.82it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.77it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.99it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:52<00:00, 45.77it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.31it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.54it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.26it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:51<00:00, 46.22it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.75it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.36it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 44.99it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 45.25it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:53<00:00, 45.12it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.51it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.32it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:53<00:00, 44.85it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.60it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.43it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 45.39it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.56it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.65it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:51<00:00, 46.73it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:52<00:00, 45.94it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.66it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 45.19it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:54<00:00, 44.08it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 44.55it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.47it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:51<00:00, 46.40it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 45.62it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:53<00:00, 44.82it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 44.81it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:53<00:00, 44.50it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:54<00:00, 43.97it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:53<00:00, 45.04it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.85it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.62it/s]


Test acc: 0.9926999807357788


100%|██████████| 2400/2400 [00:52<00:00, 45.67it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.46it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 45.73it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.36it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:52<00:00, 45.31it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 46.02it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.76it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.62it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 45.58it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 46.06it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.76it/s]


Test acc: 0.9929999709129333


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:53<00:00, 44.70it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:53<00:00, 45.22it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 45.18it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.67it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:53<00:00, 45.23it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 46.15it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.47it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.63it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.86it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.46it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 45.76it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.37it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:53<00:00, 44.84it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:53<00:00, 45.26it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:52<00:00, 45.72it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.82it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 45.72it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.61it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 46.10it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 45.74it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.57it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.72it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.81it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.73it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.74it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 46.01it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.25it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.61it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.27it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.15it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.92it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.47it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.85it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.46it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 45.68it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 46.09it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:51<00:00, 46.38it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.73it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.65it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:52<00:00, 45.84it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.18it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.27it/s]


Test acc: 0.9949999451637268


100%|██████████| 2400/2400 [00:51<00:00, 46.31it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:51<00:00, 46.62it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.69it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:51<00:00, 46.17it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.28it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 46.10it/s]


Test acc: 0.9950999617576599


100%|██████████| 2400/2400 [00:51<00:00, 46.68it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.16it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.80it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.01it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 46.14it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.30it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.73it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 45.55it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:51<00:00, 46.16it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.82it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.70it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.81it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.59it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.20it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.90it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.32it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.23it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.19it/s]


BEST Test acc: 0.995199978351593


100%|██████████| 2400/2400 [00:52<00:00, 46.00it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.43it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.38it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.79it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:51<00:00, 46.46it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 45.82it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.18it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.21it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.47it/s]


BEST Test acc: 0.9953999519348145


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 46.12it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 45.97it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 46.09it/s]


Test acc: 0.9950999617576599


100%|██████████| 2400/2400 [00:52<00:00, 46.08it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.77it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.19it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.95it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.65it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 45.99it/s]


Test acc: 0.9925999641418457


100%|██████████| 2400/2400 [00:51<00:00, 46.47it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.25it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9953999519348145


100%|██████████| 2400/2400 [00:51<00:00, 46.21it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.69it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.28it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.25it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.54it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.40it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:51<00:00, 46.26it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 45.91it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.33it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.23it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 46.07it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 46.01it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.81it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.90it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 46.08it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:51<00:00, 46.51it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 45.97it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.26it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.70it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 46.06it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:52<00:00, 46.08it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 46.04it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.98it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:51<00:00, 46.31it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 46.14it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.18it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 46.01it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:51<00:00, 46.37it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 46.08it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.46it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:52<00:00, 46.08it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 46.09it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.31it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.32it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 45.98it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 46.15it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 45.38it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.10it/s]


Test acc: 0.9952999949455261


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.92it/s]


Test acc: 0.9930999875068665


100%|██████████| 2400/2400 [00:51<00:00, 46.22it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.28it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.35it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:51<00:00, 46.48it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:51<00:00, 46.37it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:51<00:00, 46.36it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:51<00:00, 46.32it/s]


Test acc: 0.9949999451637268


100%|██████████| 2400/2400 [00:51<00:00, 46.55it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:51<00:00, 46.28it/s]


Test acc: 0.9928999543190002


100%|██████████| 2400/2400 [00:52<00:00, 45.93it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 46.11it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.89it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.91it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 45.96it/s]


Test acc: 0.9929999709129333


100%|██████████| 2400/2400 [00:52<00:00, 45.92it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.07it/s]


Test acc: 0.993399977684021


100%|██████████| 2400/2400 [00:52<00:00, 45.91it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:52<00:00, 45.99it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:52<00:00, 45.90it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:51<00:00, 46.29it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.97it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 46.09it/s]


Test acc: 0.9940999746322632


100%|██████████| 2400/2400 [00:52<00:00, 46.04it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.44it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 46.12it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.36it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.02it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 46.13it/s]


Test acc: 0.9937999844551086


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:51<00:00, 46.37it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:51<00:00, 46.21it/s]


Test acc: 0.9935999512672424


100%|██████████| 2400/2400 [00:52<00:00, 45.82it/s]


Test acc: 0.9932999610900879


100%|██████████| 2400/2400 [00:51<00:00, 46.49it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:51<00:00, 46.24it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.02it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:52<00:00, 46.05it/s]


Test acc: 0.9947999715805054


100%|██████████| 2400/2400 [00:52<00:00, 46.14it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.19it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.29it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.54it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.04it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:51<00:00, 46.15it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:51<00:00, 46.22it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:51<00:00, 46.45it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.29it/s]


Test acc: 0.9939000010490417


100%|██████████| 2400/2400 [00:52<00:00, 46.14it/s]


Test acc: 0.9941999912261963


100%|██████████| 2400/2400 [00:52<00:00, 46.02it/s]


Test acc: 0.9948999881744385


100%|██████████| 2400/2400 [00:51<00:00, 46.18it/s]


Test acc: 0.995199978351593


100%|██████████| 2400/2400 [00:51<00:00, 46.19it/s]


Test acc: 0.9945999979972839


100%|██████████| 2400/2400 [00:52<00:00, 46.03it/s]


Test acc: 0.9943999648094177


100%|██████████| 2400/2400 [00:51<00:00, 46.31it/s]


Test acc: 0.9934999942779541


100%|██████████| 2400/2400 [00:52<00:00, 46.13it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:52<00:00, 45.76it/s]


Test acc: 0.995199978351593


100%|██████████| 2400/2400 [00:51<00:00, 46.31it/s]


Test acc: 0.9936999678611755


100%|██████████| 2400/2400 [00:51<00:00, 46.34it/s]


Test acc: 0.9944999814033508


100%|██████████| 2400/2400 [00:52<00:00, 45.97it/s]


Test acc: 0.9932000041007996


100%|██████████| 2400/2400 [00:52<00:00, 46.13it/s]


Test acc: 0.9939999580383301


100%|██████████| 2400/2400 [00:52<00:00, 45.86it/s]


Test acc: 0.9942999482154846


100%|██████████| 2400/2400 [00:51<00:00, 46.34it/s]


Test acc: 0.9946999549865723


100%|██████████| 2400/2400 [00:51<00:00, 46.17it/s]


In [196]:
# print(f"Train acc: {model._test(train_loader)}")
# print(f"Test acc:  {model._test(test_loader)}")

In [12]:
# Load the saved checkpoint, report its test accuracy, and (optionally) write
# predictions to a CSV for submission.
net = NN(input_size=input_size, num_classes=num_classes).to(device)
# map_location: a checkpoint saved on GPU can still be loaded on a CPU-only
# machine. weights_only=True: only tensors / plain containers are unpickled
# (no arbitrary code execution from a tampered file), which also silences
# torch's FutureWarning about the weights_only default.
state_dict = torch.load("model.pt", map_location=device, weights_only=True)
net.load_state_dict(state_dict)
# Inference mode: BatchNorm1d layers use their running statistics instead of
# per-batch statistics (harmless if _test already switches modes itself).
net.eval()
print(net._test(test_loader))

# Uncomment to write predictions for submission:
# data = net._predict(test_loader)
# df = pd.DataFrame(data)
# df.to_csv("submission.csv", index=False)

  net.load_state_dict(torch.load("model.pt"))


0.9954


# Results

### Arch 1

First I tried different optimizers on a network with a single hidden layer of 100 neurons. As the table shows, AdamW with lr = 0.001 performs best here.

| Optimizer | Params                                                 | Accuracy |
|-----------|--------------------------------------------------------|----------|
| NAdam     | epochs = 100, lr = 0.001, decay = 0.001                | 0.976099967956543 |
| AdamW     | epochs = 100, lr = 0.001, decay = 0.001                | 0.979699969291687 |
| AdamW     | epochs = 100, lr = 0.001, decay = 0.001                | 0.9790999889373779 |
| AdamW     | epochs = 100, lr = 0.01, decay = 0.001                 | 0.9747999906539917 |
| Adadelta  | epochs = 100, decay = 0.001                            | 0.974399983882904 |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 100),
    nn.ReLU(),
    nn.Linear(100, num_classes),
)
```

### Arch 2

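Then I tried adding more intermediate hidden layers, hoping the network would learn more abstract patterns. This did not help: accuracy was slightly lower than Arch 1 trained with the same optimizer settings (0.9731 vs. 0.9748 for AdamW with lr = 0.01).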

| Optimizer | Params                                                 | Accuracy |
|-----------|--------------------------------------------------------|----------|
| AdamW     | epochs = 100, lr = 0.01, decay = 0.001                 | 0.973099946975708 |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 100),
    nn.ReLU(),
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 40),
    nn.ReLU(),
    nn.Linear(40, 20),
    nn.ReLU(),
    nn.Linear(20, num_classes),
)
```

### Arch 3

I then tried reducing the number of neurons in the hidden layer of the first architecture, hoping to lower memory costs. This did not pay off: accuracy dropped to 0.9699. Note that this run also used only 50 epochs instead of 100.

| Optimizer | Params                                                 | Accuracy |
|-----------|--------------------------------------------------------|----------|
| AdamW     | epochs = 50, lr = 0.01, decay = 0.001                 | 0.9698999524116516 |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 50),
    nn.ReLU(),
    nn.Linear(50, num_classes),
)
```

### Arch 4

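Back to more hidden layers, this time considerably larger ones (1024 neurons each). This gave a large improvement of about 0.02 in accuracy over Arch 2. Data augmentation was also introduced in this run, so part of the gain may come from that.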

| Optimizer | Params                                                 | Accuracy |
|-----------|--------------------------------------------------------|----------|
| AdamW     | epochs = 50, lr = 0.001, decay = 0.001, augmentation  | 0.9926999807357788 |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, num_classes),
)
```

### Arch 4.2

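I also tried the batch-normalized network from Arch 5 with 512 instead of 1024 neurons per hidden layer. The results were slightly inferior to the best 1024-neuron run (0.9949 vs. 0.9951).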

| Optimizer | Params                                                 | Graphs | Accuracy |
|-----------|--------------------------------------------------------|--------|----------|
| AdamW     | epochs = 50, lr = 0.001, decay = 0.001, augmentation  | [acc](graphs/acc03.png), [loss](graphs/loss03.png) | 0.9948999881744385 |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)
```

### Arch 5

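I then added batch normalization, which normalizes each hidden layer's activations over the mini-batch. This keeps their distribution stable during training and helps the network learn its parameters more easily.
For the second run I also added Xavier weight initialization, which improved the results further. The final run trained the same setup for 500 epochs and kept the checkpoint with the highest test accuracy. Its reported accuracy is therefore the best value over all epochs, selected on the test set, and is somewhat optimistic.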

| Optimizer | Params                                                 | Graphs | Model | Accuracy |
|-----------|--------------------------------------------------------|--------|-------|---------|
| AdamW     | epochs = 50, lr = 0.001, decay = 0.001, augmentation  | [acc](graphs/acc01.png), [loss](graphs/loss01.png) | [model](model01.pt) | 0.9948999881744385 |
| AdamW     | epochs = 50, lr = 0.001, decay = 0.001, augmentation, xavier weights  | [acc](graphs/acc02.png), [loss](graphs/loss02.png) | [model](model02.pt) | 0.9950999617576599 |
| AdamW     | epochs = 500, lr = 0.001, decay = 0.001, augmentation, xavier weights  | - | [model](model03.pt) | **0.9954** |

```py
self.net = nn.Sequential(
    nn.Linear(input_size, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(),
    nn.Linear(1024, num_classes),
)
```