In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

BATCH_SIZE = 64
DATA_ROOT = "../handwriting_letter_recognition"
# Number of images held out of the MNIST training split for validation.
VAL_SIZE = 10000
# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42

# The full MNIST training split; it is partitioned into train/validation below.
full_train_dataset = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Derive the train size from the dataset instead of hard-coding it, and pass a
# seeded generator: without one, random_split draws from torch's global RNG and
# the validation set changes on every kernel restart.
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [len(full_train_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train dataset: {len(train_dataset)}")
print(f"Validation dataset: {len(val_dataset)}")
print(f"Test dataset: {len(test_dataset)}")

Train dataset: 50000
Validation dataset: 10000
Test dataset: 10000


In [2]:
import torch
from torch import nn, optim
from torchsummary import summary

# Number of output logits produced by the classifier head (one per MNIST label).
NUM_CLASSES = 10
# Step size passed to the Adam optimizer created below.
LEARNING_RATE = 0.001


class MNISTModel(nn.Module):
    """Single-convolution classifier for 1-channel square images (MNIST-sized by default).

    Architecture: Conv2d(1 -> 32, 3x3, no padding) -> Sigmoid -> MaxPool2d(2x2, stride 2)
    -> flatten -> Linear(num_classes).

    Args:
        num_classes: Number of output logits. When omitted, falls back to the
            notebook-level ``NUM_CLASSES`` constant.
        input_size: Side length in pixels of the square input images. It sizes
            the final linear layer; 28 matches the original hard-coded layout.
    """

    def __init__(self, num_classes=None, input_size=28):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.sigmoid = nn.Sigmoid()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # The unpadded 3x3 conv shrinks each side by 2; the 2x2/stride-2 pool then
        # halves it (rounding down). For 28x28 inputs this gives 13x13 feature maps.
        pooled_size = (input_size - 2) // 2
        self.fc = nn.Linear(32 * pooled_size * pooled_size, num_classes)

    def forward(self, x):
        """Map a (batch, 1, input_size, input_size) tensor to (batch, num_classes) logits.

        The outputs are unnormalised scores (no softmax), which is what
        ``nn.CrossEntropyLoss`` expects.
        """
        x = self.conv(x)
        x = self.sigmoid(x)
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x


SEED = 42
# Seed torch's global RNG before building the model so weight initialisation and
# the train_loader shuffle order (which draws from this RNG) are repeatable.
# NOTE(review): some CUDA/cuDNN kernels can still be nondeterministic; see the
# PyTorch reproducibility notes if bit-exact GPU runs are required.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNISTModel().to(device)
# CrossEntropyLoss consumes the raw logits returned by MNISTModel.forward.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

summary(model, (1, 28, 28), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
           Sigmoid-2           [-1, 32, 26, 26]               0
         MaxPool2d-3           [-1, 32, 13, 13]               0
            Linear-4                   [-1, 10]          54,090
Total params: 54,410
Trainable params: 54,410
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.37
Params size (MB): 0.21
Estimated Total Size (MB): 0.58
----------------------------------------------------------------


In [3]:
from tqdm import tqdm

history = {
    "accuracy": [],
    "loss": [],
    "val_accuracy": [],
    "val_loss": [],
}

EPOCHS = 20


def evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean loss and accuracy of ``model`` over ``loader``.

    Switches the model to eval mode and disables autograd, so no computation graph
    is built for these forward passes. Callers that keep training must call
    ``model.train()`` again afterwards.

    Args:
        model: Module mapping a batch of images to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss assumed to average over the batch (the default
            ``reduction="mean"`` of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample, so a short final batch is
        not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Undo the batch mean so batches of different sizes weigh correctly.
            loss_sum += criterion(outputs, labels).item() * batch_size
            # argmax of the logits equals argmax of softmax(logits); no softmax needed.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return loss_sum / total, correct / total


for epoch in range(EPOCHS):
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    train_pbar = tqdm(train_loader, total=len(train_loader))
    for images, labels in train_pbar:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

        # Running metrics are normalised by the samples seen so far; dividing by
        # the full dataset size would under-report accuracy until the last batch.
        train_pbar.set_description(
            f"Epoch {epoch} - loss: {loss_sum / seen:.4f} - accuracy: {correct / seen:.4f}"
        )

    train_loss = loss_sum / seen
    train_accuracy = correct / seen
    val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)

    history["loss"].append(train_loss)
    history["accuracy"].append(train_accuracy)
    history["val_loss"].append(val_loss)
    history["val_accuracy"].append(val_accuracy)

    print(
        f"Epoch {epoch} - loss: {train_loss:.4f} - accuracy: {train_accuracy:.4f} - val_loss: {val_loss:.4f} - val_accuracy: {val_accuracy:.4f}"
    )

Epoch 0 - loss: 0.8559 - accuracy: 0.7410: 100%|██████████| 782/782 [00:00<00:00, 818.26it/s]   


Epoch 0 - loss: 0.8559 - accuracy: 0.7410 - val_loss: 0.3836 - val_accuracy: 0.8846


Epoch 1 - loss: 0.3666 - accuracy: 0.8886: 100%|██████████| 782/782 [00:07<00:00, 99.29it/s]


Epoch 1 - loss: 0.3666 - accuracy: 0.8886 - val_loss: 0.3271 - val_accuracy: 0.8972


Epoch 2 - loss: 0.3219 - accuracy: 0.9024: 100%|██████████| 782/782 [00:14<00:00, 52.57it/s] 


Epoch 2 - loss: 0.3219 - accuracy: 0.9024 - val_loss: 0.3258 - val_accuracy: 0.9056


Epoch 3 - loss: 0.2734 - accuracy: 0.9198: 100%|██████████| 782/782 [00:01<00:00, 493.82it/s]   


Epoch 3 - loss: 0.2734 - accuracy: 0.9198 - val_loss: 0.2336 - val_accuracy: 0.9304


Epoch 4 - loss: 0.2321 - accuracy: 0.9308: 100%|██████████| 782/782 [00:08<00:00, 97.74it/s] 


Epoch 4 - loss: 0.2321 - accuracy: 0.9308 - val_loss: 0.1904 - val_accuracy: 0.9467


Epoch 5 - loss: 0.1966 - accuracy: 0.9421: 100%|██████████| 782/782 [00:14<00:00, 52.44it/s] 


Epoch 5 - loss: 0.1966 - accuracy: 0.9421 - val_loss: 0.1769 - val_accuracy: 0.9476


Epoch 6 - loss: 0.1712 - accuracy: 0.9496: 100%|██████████| 782/782 [00:01<00:00, 654.67it/s]  


Epoch 6 - loss: 0.1712 - accuracy: 0.9496 - val_loss: 0.1545 - val_accuracy: 0.9513


Epoch 7 - loss: 0.1439 - accuracy: 0.9572: 100%|██████████| 782/782 [00:07<00:00, 98.38it/s] 


Epoch 7 - loss: 0.1439 - accuracy: 0.9572 - val_loss: 0.1460 - val_accuracy: 0.9561


Epoch 8 - loss: 0.1225 - accuracy: 0.9638: 100%|██████████| 782/782 [00:08<00:00, 96.38it/s]


Epoch 8 - loss: 0.1225 - accuracy: 0.9638 - val_loss: 0.1002 - val_accuracy: 0.9711


Epoch 9 - loss: 0.1153 - accuracy: 0.9651: 100%|██████████| 782/782 [00:08<00:00, 95.10it/s] 


Epoch 9 - loss: 0.1153 - accuracy: 0.9651 - val_loss: 0.1099 - val_accuracy: 0.9678


Epoch 10 - loss: 0.1047 - accuracy: 0.9681: 100%|██████████| 782/782 [00:07<00:00, 103.99it/s]


Epoch 10 - loss: 0.1047 - accuracy: 0.9681 - val_loss: 0.1300 - val_accuracy: 0.9619


Epoch 11 - loss: 0.0967 - accuracy: 0.9701: 100%|██████████| 782/782 [00:07<00:00, 100.69it/s]


Epoch 11 - loss: 0.0967 - accuracy: 0.9701 - val_loss: 0.0949 - val_accuracy: 0.9720


Epoch 12 - loss: 0.0878 - accuracy: 0.9730: 100%|██████████| 782/782 [00:14<00:00, 54.48it/s]


Epoch 12 - loss: 0.0878 - accuracy: 0.9730 - val_loss: 0.0800 - val_accuracy: 0.9780


Epoch 13 - loss: 0.0774 - accuracy: 0.9764: 100%|██████████| 782/782 [00:07<00:00, 105.13it/s]


Epoch 13 - loss: 0.0774 - accuracy: 0.9764 - val_loss: 0.1029 - val_accuracy: 0.9688


Epoch 14 - loss: 0.0774 - accuracy: 0.9768: 100%|██████████| 782/782 [00:07<00:00, 97.79it/s] 


Epoch 14 - loss: 0.0774 - accuracy: 0.9768 - val_loss: 0.0962 - val_accuracy: 0.9733


Epoch 15 - loss: 0.0735 - accuracy: 0.9770: 100%|██████████| 782/782 [00:08<00:00, 95.75it/s]


Epoch 15 - loss: 0.0735 - accuracy: 0.9770 - val_loss: 0.0907 - val_accuracy: 0.9746


Epoch 16 - loss: 0.0672 - accuracy: 0.9792: 100%|██████████| 782/782 [00:08<00:00, 97.55it/s]


Epoch 16 - loss: 0.0672 - accuracy: 0.9792 - val_loss: 0.0760 - val_accuracy: 0.9776


Epoch 17 - loss: 0.0634 - accuracy: 0.9802: 100%|██████████| 782/782 [00:08<00:00, 95.17it/s]


Epoch 17 - loss: 0.0634 - accuracy: 0.9802 - val_loss: 0.0714 - val_accuracy: 0.9791


Epoch 18 - loss: 0.0599 - accuracy: 0.9815: 100%|██████████| 782/782 [00:07<00:00, 97.78it/s]


Epoch 18 - loss: 0.0599 - accuracy: 0.9815 - val_loss: 0.0765 - val_accuracy: 0.9794


Epoch 19 - loss: 0.0571 - accuracy: 0.9818: 100%|██████████| 782/782 [00:14<00:00, 54.39it/s] 


Epoch 19 - loss: 0.0571 - accuracy: 0.9818 - val_loss: 0.1035 - val_accuracy: 0.9699


In [4]:
import json

# Write the per-epoch metrics collected by the training loop to history.json.
with open("history.json", "w") as history_file:
    history_file.write(json.dumps(history))

In [5]:
# Persist only the learned parameters (the state_dict), not the module object itself.
model_state = model.state_dict()
torch.save(model_state, "mnist_model.pth")