In [1]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

In [2]:
class EarlyStopping:
    """Signal when training should stop because validation loss stopped improving.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves, the model's ``state_dict`` is written to
    ``path``. After ``patience`` consecutive calls without improvement,
    ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: Minimum decrease in validation loss, relative to the best loss
            seen so far, that counts as an improvement.
        path: File the best ``state_dict`` is written to via ``torch.save``.
        verbose: If True, print a message on every checkpoint and on every
            non-improving call.

    Attributes:
        counter: Consecutive non-improving calls since the last improvement.
        best_score: Negated best validation loss so far (None until the first
            valid loss is seen).
        early_stop: True once ``counter`` has reached ``patience``.
        val_loss_min: Validation loss at the most recent checkpoint.
    """

    def __init__(self, patience=5, delta=0, path='checkpoint.pt', verbose=False):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint ``model`` if it improved.

        Args:
            val_loss: Validation loss for the epoch (lower is better).
            model: Object exposing ``state_dict()``; saved on improvement.
        """
        # NaN compares False against everything, so without this guard a
        # diverged epoch would be treated as an improvement: it would overwrite
        # the best checkpoint and turn best_score into NaN, after which every
        # later epoch would also look "improved" and stopping would never fire.
        if math.isnan(val_loss):
            self._register_no_improvement()
            return

        score = -val_loss  # negate so that "higher is better"
        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self._register_no_improvement()

    def _register_no_improvement(self):
        """Count one non-improving epoch and raise ``early_stop`` at the limit."""
        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [3]:
class NeuralNetModel(nn.Module):
    """Fully connected network with one hidden layer: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_layer_size, output_size):
        """Create the two linear layers and the activation applied between them.

        Args:
            input_size: Number of features in each input vector.
            hidden_layer_size: Number of units in the hidden layer.
            output_size: Number of values produced per input vector.
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.layer1 = nn.Linear(in_features=input_size, out_features=hidden_layer_size)
        self.layer2 = nn.Linear(in_features=hidden_layer_size, out_features=output_size)

    def forward(self, x):
        """Return the raw (un-normalised) outputs of the second linear layer for ``x``."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)

In [4]:
def _flatten_batch(images, labels, device):
    """Move a batch to ``device`` and flatten every sample to a 1-D feature vector."""
    images, labels = images.to(device), labels.to(device)
    # Flatten per sample instead of hard-coding 28 * 28 so other input sizes work.
    return images.reshape(images.size(0), -1), labels


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss."""
    model.train()
    loss_sum, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = _flatten_batch(images, labels, device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size; assumes ``criterion`` returns the batch mean.
        loss_sum += loss.item() * images.size(0)
        n_seen += images.size(0)

    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")
    # Divide by the samples actually seen, which stays correct when the loader
    # uses drop_last or a sampler that does not cover the whole dataset.
    return loss_sum / n_seen


def _evaluate(model, loader, criterion, device):
    """Evaluate without gradients; return (mean per-sample loss, accuracy in percent)."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = _flatten_batch(images, labels, device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * images.size(0)

            _, predicted = outputs.max(1)
            n_correct += (predicted == labels).sum().item()
            n_seen += images.size(0)

    if n_seen == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / n_seen, 100. * n_correct / n_seen


def _plot_history(train_losses, val_losses, val_accuracies):
    """Plot the loss curves and the validation accuracy side by side."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
    ax_loss.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
    ax_loss.set_xlabel("Epochs")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()
    ax_loss.set_title("Training and Validation Loss")

    ax_acc.plot(range(1, len(val_accuracies) + 1), val_accuracies, label="Validation Accuracy", color="orange")
    ax_acc.set_xlabel("Epochs")
    ax_acc.set_ylabel("Accuracy (%)")
    ax_acc.legend()
    ax_acc.set_title("Validation Accuracy")

    fig.tight_layout()
    plt.show()


def train(model, train_loader, val_loader, criterion, optimizer, n_epochs=100, patience=5,
          device=None, restore_best=False):
    """Train ``model`` with early stopping on validation loss, then plot the history.

    Each epoch runs one training pass and one validation pass, prints a summary
    line, and hands the validation loss to an ``EarlyStopping`` tracker, which
    checkpoints the model on improvement. Training ends after ``n_epochs`` or
    as soon as the tracker sets ``early_stop``.

    Args:
        model: Network to optimise; updated in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        n_epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping early.
        device: Device batches are moved to. When None it is taken from the
            model's first parameter (CPU if the model has no parameters), so
            the function no longer depends on a notebook-global ``device``.
        restore_best: If True, reload the best checkpoint into ``model`` once
            training ends. Defaults to False, which leaves the model with its
            last-epoch weights as before.

    Raises:
        ValueError: If either loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    early_stopping = EarlyStopping(patience=patience, verbose=True)
    train_losses, val_losses, val_accuracies = [], [], []

    for epoch in range(1, n_epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)

        val_loss, accuracy = _evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(accuracy)

        print(f"Epoch {epoch}/{n_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Val Accuracy: {accuracy:.2f}%")

        # Check early stopping
        early_stopping(val_loss, model)

        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Only reload when this run actually wrote a checkpoint; otherwise a stale
    # file left over from an earlier run could be loaded by mistake.
    if restore_best and early_stopping.best_score is not None:
        model.load_state_dict(torch.load(early_stopping.path, map_location=device))

    _plot_history(train_losses, val_losses, val_accuracies)

In [5]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters
SEED = 42
batch_size = 64
validation_split = 0.2
n_epochs = 20
patience = 5

# Seed the global RNG so batch shuffling and any weight initialisation that runs
# after this cell repeat across "Restart & Run All". Some GPU kernels may still
# be non-deterministic.
torch.manual_seed(SEED)

# Data transformation and loading
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# Split dataset into training and validation sets
train_size = int((1 - validation_split) * len(dataset))
val_size = len(dataset) - train_size
# A dedicated seeded generator keeps the split identical between runs, regardless
# of how much global randomness was consumed before this line.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Network dimensions: flattened 28x28 input, one hidden layer, 10 outputs.
input_size = 28 * 28
hidden_layer_size = 100
output_size = 10

# Build the network, then move it to the selected device.
model = NeuralNetModel(
    input_size=input_size,
    hidden_layer_size=hidden_layer_size,
    output_size=output_size,
)
model = model.to(device)

# Loss and optimizer used by the training loop.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Launch training with the model, loaders, loss and optimizer set up in the cells above.
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss_function,
    optimizer=optimizer,
    n_epochs=n_epochs,
    patience=patience,
)

Epoch 1/20 - Train Loss: 0.4300 - Val Loss: 0.2754 - Val Accuracy: 91.99%
Validation loss decreased (inf --> 0.275433). Saving model...
Epoch 2/20 - Train Loss: 0.2435 - Val Loss: 0.2104 - Val Accuracy: 94.17%
Validation loss decreased (0.275433 --> 0.210435). Saving model...
Epoch 3/20 - Train Loss: 0.1853 - Val Loss: 0.1714 - Val Accuracy: 95.12%
Validation loss decreased (0.210435 --> 0.171395). Saving model...
Epoch 4/20 - Train Loss: 0.1523 - Val Loss: 0.1638 - Val Accuracy: 95.30%
Validation loss decreased (0.171395 --> 0.163802). Saving model...
Epoch 5/20 - Train Loss: 0.1297 - Val Loss: 0.1490 - Val Accuracy: 95.38%
Validation loss decreased (0.163802 --> 0.148977). Saving model...
