In [111]:
import torch
import torchvision as tv

In [112]:
# Mini-batch size shared by the train, validation and test DataLoaders.
BATCH_SIZE = 64

In [113]:
# MNIST train / test splits, stored under the current directory and downloaded
# on first use. ToTensor converts each image to a tensor when it is accessed.
train_dataset = tv.datasets.MNIST(
    root='.',
    train=True,
    transform=tv.transforms.ToTensor(),
    download=True,
)
test_dataset = tv.datasets.MNIST(
    root='.',
    train=False,
    transform=tv.transforms.ToTensor(),
    download=True,
)

In [114]:
# Hold out a fraction of the official training set for validation.
VAL_FRACTION = 0.1
SPLIT_SEED = 42

# Derive the training size by subtraction so the two lengths always sum to
# len(train_dataset). Truncating both fractions independently can drop samples,
# which makes random_split raise a ValueError.
n_val = int(VAL_FRACTION * len(train_dataset))
n_train = len(train_dataset) - n_val

# A dedicated seeded generator keeps the validation set identical across re-runs
# of this cell, so a model trained earlier never sees reshuffled validation
# samples as training data.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset, [n_train, n_val], generator=split_generator
)
assert len(train_subset) + len(val_subset) == len(train_dataset)

# Only the training loader is shuffled; evaluation order does not matter.
train = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val = torch.utils.data.DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)
test = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [132]:
# ===== MODEL =====
# Fully connected network: flattened 28x28 input -> 128 hidden units (ReLU) -> 10 outputs.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=28 * 28, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=10),
)

# ===== Loss and Optimizer =====
# Cross-entropy is applied to the raw outputs of the last linear layer;
# Adam runs with its default hyperparameters.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())


NUM_EPOCHS = 5  # number of full passes over the training loader

for epoch in range(NUM_EPOCHS):
    # ===== TRAIN =====
    model.train()
    # Accumulate per-sample sums rather than per-batch means: the last batch of
    # a loader can be smaller than the others, and averaging batch means would
    # over-weight it.
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0

    for X, y in train:
        optimizer.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        n_batch = y.size(0)
        # loss_fn returns the batch mean, so scale it back to a batch sum.
        train_loss_sum += loss.item() * n_batch
        # Accuracy uses the predictions made before this step's weight update.
        train_correct += (y_pred.argmax(dim=1) == y).sum().item()
        train_seen += n_batch

    train_loss = train_loss_sum / train_seen
    train_acc = train_correct / train_seen

    # ===== VALIDATION =====
    model.eval()
    val_loss_sum = 0.0
    val_correct = 0
    val_seen = 0

    with torch.no_grad():  # no gradients are needed for evaluation
        for X, y in val:
            y_pred = model(X)
            n_batch = y.size(0)
            val_loss_sum += loss_fn(y_pred, y).item() * n_batch
            val_correct += (y_pred.argmax(dim=1) == y).sum().item()
            val_seen += n_batch

    # Plain Python floats, computed the same way as the training metrics.
    val_loss = val_loss_sum / val_seen
    val_acc = val_correct / val_seen

    # ===== PRINT =====
    print(f'Epoch №{epoch+1}')
    print(f'Loss train: {train_loss:.4f}, Acc train: {train_acc:.4f}')
    print(f'Loss val: {val_loss:.4f}, Acc val: {val_acc:.4f}\n')


Epoch №1
Loss train: 0.3564, Acc train: 0.9047
Loss val: 0.1962, Acc val: 0.9458

Epoch №2
Loss train: 0.1606, Acc train: 0.9532
Loss val: 0.1417, Acc val: 0.9596

Epoch №3
Loss train: 0.1140, Acc train: 0.9670
Loss val: 0.1140, Acc val: 0.9673

Epoch №4
Loss train: 0.0881, Acc train: 0.9737
Loss val: 0.1050, Acc val: 0.9694

Epoch №5
Loss train: 0.0702, Acc train: 0.9790
Loss val: 0.0903, Acc val: 0.9747



In [133]:
# Final evaluation on the held-out test loader, with sample-weighted loss and accuracy.
model.eval()
test_loss, test_correct, test_total = 0, 0, 0

with torch.no_grad():
    for inputs, targets in test:
        logits = model(inputs)
        n_batch = targets.size(0)
        batch_loss = loss_fn(logits, targets).item()

        # Scale the batch-mean loss back to a sum so every sample counts equally.
        test_loss += batch_loss * n_batch
        test_correct += (logits.argmax(dim=1) == targets).sum().item()
        test_total += n_batch

test_loss = test_loss / test_total
test_acc = test_correct / test_total

print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')

Test Loss: 0.0802, Test Accuracy: 0.9759
