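Implement early stopping as a form of regularization. Train a neural network while monitoring its validation loss, and stop training once the validation loss stops improving — i.e. it has not decreased for a fixed number of consecutive epochs (the *patience*). Then compare the resulting model's performance with that of a model trained for the full number of epochs without early stopping.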

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

In [2]:
class CNNClassifier(nn.Module):
    """Three-stage CNN mapping single-channel images to 10 class logits.

    Each stage is Conv2d(3x3, no padding) -> ReLU -> MaxPool2d(2x2, stride 2),
    with channels 1 -> 64 -> 128 -> 64. The head is
    Linear(64, 20) -> ReLU -> Linear(20, 10), so the feature extractor must
    reduce every sample to exactly 64 values; for 28x28 inputs the spatial
    size goes 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) for each conv stage, in forward order.
        channel_plan = ((1, 64), (64, 128), (128, 64))
        stages = []
        for c_in, c_out in channel_plan:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            ]
        self.net = nn.Sequential(*stages)
        self.classification_head = nn.Sequential(
            nn.Linear(64, 20),
            nn.ReLU(),
            nn.Linear(20, 10),
        )

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10)."""
        feats = self.net(x)
        # Collapse (batch, C, H, W) to (batch, C*H*W) for the linear head.
        flat = feats.view(feats.size(0), -1)
        return self.classification_head(flat)

In [3]:
def train(model, train_loader, criterion, optimizer, device="cpu"):
    """Run one optimisation epoch over ``train_loader``.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function. The per-sample averaging below assumes it
            returns the *mean* loss over the batch, as ``nn.CrossEntropyLoss``
            does by default — confirm if another criterion is used.
        optimizer: optimiser that updates ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(accuracy, loss)`` as Python floats: accuracy in percent and the
        sample-weighted mean loss, both over the samples the loader actually
        yielded (correct with ``drop_last`` or subset samplers). Both are
        accumulated while the weights are changing, so they describe the
        epoch rather than the final weights.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = len(target)
        # criterion yields a batch mean; re-weight by batch size so the epoch
        # loss is a true per-sample average even when the last batch is smaller.
        loss_sum += loss.item() * batch_size
        n_correct += (output.argmax(dim=1) == target).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")
    # Divide by samples actually seen, not len(dataset): they differ with
    # drop_last=True or a sampler over a subset (e.g. a train/val split).
    return 100.0 * n_correct / n_seen, loss_sum / n_seen

def test(model, test_loader, criterion, device="cpu"):
    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            running_loss += criterion(output, target).item() * len(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    running_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    return acc, running_loss

In [4]:
# Run configuration.
EPOCHS, PATIENCE = 20, 5      # max training epochs; early-stopping patience in epochs
BATCH_SIZE_TRAIN, BATCH_SIZE_TEST = 64, 1000
LR = 1e-3                     # learning rate handed to the optimiser
LOG_INTERVAL = 100            # NOTE(review): not read by the training code visible here — confirm it is still needed
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# MNIST train/test splits as tensors; download=True fetches them into data/ if absent.
train_dataset = MNIST('data/', train=True, download=True, transform=ToTensor())
test_dataset = MNIST('data/', train=False, download=True, transform=ToTensor())

# Only the training data needs shuffling: evaluation loss/accuracy do not depend
# on sample order, so shuffling the evaluation loader just adds nondeterminism.
# NOTE(review): the training loop uses test_loader as the early-stopping
# validation set, so the test set also drives model selection. A held-out split
# of train_dataset would keep the final test estimate unbiased.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False)

In [6]:
# Loss, network (moved to the selected device) and Adam optimiser over its parameters.
criterion = nn.CrossEntropyLoss()
model = CNNClassifier()
model = model.to(DEVICE)  # Module.to moves parameters in place and returns the module
optimizer = optim.Adam(params=model.parameters(), lr=LR)

In [7]:
print(str(model))  # layer-by-layer summary of the network

CNNClassifier(
  (net): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classification_head): Sequential(
    (0): Linear(in_features=64, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=10, bias=True)
  )
)


In [8]:
import copy

# Early stopping on the validation loss:
# - after each epoch, remember the weights if the validation loss improved;
# - stop once PATIENCE consecutive epochs fail to improve on the best loss;
# - finally restore the best weights, so the model is the one early stopping selected.
best_test_loss = float('inf')
best_state = None       # deep copy of model.state_dict() at the best epoch
best_epoch = 0
current_patience = 0    # consecutive epochs without improvement
for epoch in range(1, EPOCHS + 1):
    print(f"Epoch: {epoch}")

    print("\tTraining: ")
    train_acc, train_loss = train(model, train_loader, criterion, optimizer, DEVICE)
    print("\t\tAccuracy: {:.4}%".format(train_acc))
    print("\t\tLoss: {:.4}".format(train_loss))

    print("\tValidation: ")
    # NOTE(review): test_loader doubles as the validation set here, so the test
    # set influences model selection; prefer a held-out split of the training data.
    test_acc, test_loss = test(model, test_loader, criterion, DEVICE)
    print("\t\tAccuracy: {:.4}%".format(test_acc))
    print("\t\tLoss: {:.4}".format(test_loss))

    if test_loss < best_test_loss:
        best_test_loss = test_loss
        best_epoch = epoch
        # deepcopy: state_dict() holds references to the live parameter tensors,
        # which later optimizer steps would overwrite.
        best_state = copy.deepcopy(model.state_dict())
        current_patience = 0
    else:
        current_patience += 1

        # ">=" stops after exactly PATIENCE non-improving epochs, as the message
        # says; the previous ">" waited one extra epoch.
        if current_patience >= PATIENCE:
            print(f'\nEarly Stopping! No improvement for {PATIENCE} epochs.')
            break

    print()

# Roll back to the best epoch; otherwise the model keeps the later, worse weights.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (validation loss {best_test_loss:.4}).")

Epoch: 1
	Training: 
		Accuracy: 86.63%
		Loss: 0.4113
	Validation: 
		Accuracy: 95.86%
		Loss: 0.1493

Epoch: 2
	Training: 
		Accuracy: 96.6%
		Loss: 0.1146
	Validation: 
		Accuracy: 97.41%
		Loss: 0.08762

Epoch: 3
	Training: 
		Accuracy: 97.56%
		Loss: 0.08123
	Validation: 
		Accuracy: 97.07%
		Loss: 0.09973

Epoch: 4
	Training: 
		Accuracy: 98.1%
		Loss: 0.06335
	Validation: 
		Accuracy: 97.33%
		Loss: 0.09201

Epoch: 5
	Training: 
		Accuracy: 98.39%
		Loss: 0.05404
	Validation: 
		Accuracy: 98.09%
		Loss: 0.06142

Epoch: 6
	Training: 
		Accuracy: 98.63%
		Loss: 0.04393
	Validation: 
		Accuracy: 98.35%
		Loss: 0.05548

Epoch: 7
	Training: 
		Accuracy: 98.83%
		Loss: 0.03823
	Validation: 
		Accuracy: 98.38%
		Loss: 0.05737

Epoch: 8
	Training: 
		Accuracy: 98.96%
		Loss: 0.03176
	Validation: 
		Accuracy: 98.55%
		Loss: 0.05397

Epoch: 9
	Training: 
		Accuracy: 99.16%
		Loss: 0.02695
	Validation: 
		Accuracy: 98.63%
		Loss: 0.04925

Epoch: 10
	Training: 
		Accuracy: 99.26%
		Loss: 0.