# Basic EarlyStopping class


    """
        basically we have best loss which is lowest val loss we seen during training,
        
        The min_delta parameter represents the minimum change in the
        monitored quantity to qualify as an improvement.

        To better understand min_delta relation with current val_loss and best_loss Suppose:

                self.best_loss is 0.5
                val_loss is 0.4
                self.min_delta is 0.1
            Then, we have:
                self.best_loss - val_loss = 0.5 - 0.4 = 0.1
                0.1 >= 0.1

            Since 0.1 is greater than or equal to 0.1, the condition is satisfied, indicating an improvement.

            Now let's consider another scenario:
            Suppose:
                self.best_loss is 0.5
                val_loss is 0.55
                self.min_delta is 0.1
            Then, we have:
                self.best_loss - val_loss = 0.5 - 0.55 = -0.05
                -0.05 >= 0.1
            Since -0.05 is not greater than or equal to 0.1, the condition is not satisfied, indicating no improvement.
        """


In [4]:
import copy
import torch

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_model = None
        self.best_loss = None
        self.counter = 0
        self.status = ""

    def __call__(self, model, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = copy.deepcopy(model.state_dict())
        elif self.best_loss - val_loss >= self.min_delta:
            self.best_model = copy.deepcopy(model.state_dict())
            self.best_loss = val_loss
            self.counter = 0
            self.status = f"Improvement found, counter reset to {self.counter}"
        else:
            self.counter += 1
            self.status = f"No improvement in the last {self.counter} epochs"
            if self.counter >= self.patience:
                self.status = f"Early stopping triggered after {self.counter} epochs."
                if self.restore_best_weights:
                    model.load_state_dict(self.best_model)
                return True
        return False

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a simple neural network architecture for MNIST
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Input size is 28x28 for MNIST images
        self.fc2 = nn.Linear(128, 10)  # Output size is 10 for 10 classes (digits 0-9)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input images
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the SimpleNN model
model = SimpleNN()

# Define transformations and load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize EarlyStopping object
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

# Training loop
for epoch in range(200):  # Assume 20 epochs for demonstration
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    accuracy = 100 * correct / total

    print(f'Epoch {epoch+1}/{200}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Check early stopping criteria
    print(early_stopping.status)
    if early_stopping(model, val_loss):
        print("Early stopping triggered!")
        break

# Optionally, retrieve the best model weights
best_model = model  # Assuming early_stopping restores the best model automatically


Epoch 1/200, Validation Loss: 0.2762, Accuracy: 92.17%

Epoch 2/200, Validation Loss: 0.2262, Accuracy: 93.43%

Epoch 3/200, Validation Loss: 0.1789, Accuracy: 94.53%
Improvement found, counter reset to 0
Epoch 4/200, Validation Loss: 0.1628, Accuracy: 95.18%
Improvement found, counter reset to 0
Epoch 5/200, Validation Loss: 0.1366, Accuracy: 95.91%
Improvement found, counter reset to 0
Epoch 6/200, Validation Loss: 0.1208, Accuracy: 96.24%
Improvement found, counter reset to 0
Epoch 7/200, Validation Loss: 0.1120, Accuracy: 96.56%
Improvement found, counter reset to 0
Epoch 8/200, Validation Loss: 0.1005, Accuracy: 96.86%
Improvement found, counter reset to 0
Epoch 9/200, Validation Loss: 0.1076, Accuracy: 96.49%
Improvement found, counter reset to 0
Epoch 10/200, Validation Loss: 0.0944, Accuracy: 96.97%
No improvement in the last 1 epochs
Epoch 11/200, Validation Loss: 0.0959, Accuracy: 97.07%
Improvement found, counter reset to 0
Epoch 12/200, Validation Loss: 0.0850, Accuracy: 97

# Lets modify it a little bit to add more print statements

In [1]:
import copy

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_model = None
        self.best_loss = None
        self.best_accuracy = None  # New attribute to track best accuracy
        self.counter = 0
        self.status = ""

    def __call__(self, model, val_loss, val_accuracy):
        if self.best_loss is None or val_loss < self.best_loss:  # Update if val_loss is better
            self.best_loss = val_loss
            self.best_accuracy = val_accuracy  # Update best accuracy along with best loss
            self.best_model = copy.deepcopy(model.state_dict())
            self.counter = 0
            self.status = f"Improvement found, counter reset to {self.counter}. " \
                          # f"Best Loss: {self.best_loss:.4f}, Best Accuracy: {self.best_accuracy:.2f}%"

        elif self.best_loss - val_loss >= self.min_delta:
            self.best_model = copy.deepcopy(model.state_dict())
            self.best_loss = val_loss
            self.counter = 0
            self.status = f"Improvement found, counter reset to {self.counter}. " \
                          # f"Best Loss: {self.best_loss:.4f}, Best Accuracy: {self.best_accuracy:.2f}%"

        else:
            self.counter += 1
            self.status = f"No improvement in the last {self.counter} epochs. " \
                          # f"Best Loss: {self.best_loss:.4f}, Best Accuracy: {self.best_accuracy:.2f}%"
            if self.counter >= self.patience:
                self.status = f"Early stopping triggered after {self.counter} epochs. " \
                              f"Best Loss: {self.best_loss:.4f}, Best Accuracy: {self.best_accuracy:.2f}%"
                if self.restore_best_weights:
                    model.load_state_dict(self.best_model)
                return True
        return False


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a simple neural network architecture for MNIST
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # Input size is 28x28 for MNIST images
        self.fc2 = nn.Linear(128, 10)  # Output size is 10 for 10 classes (digits 0-9)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input images
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the SimpleNN model
model = SimpleNN()

# Define transformations and load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize EarlyStopping object
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

# Training loop
for epoch in range(200):  # Assume 20 epochs for demonstration
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    accuracy = 100 * correct / total

    print(f'Epoch {epoch+1}/{200}, Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Check early stopping criteria
    # print(early_stopping.status)
    if early_stopping(model, val_loss,accuracy):
        print(early_stopping.status)
        print("Early stopping triggered!")
        break

# Optionally, retrieve the best model weights
best_model = model  # Assuming early_stopping restores the best model automatically


Epoch 1/200, Validation Loss: 0.3109, Accuracy: 90.69%
Epoch 2/200, Validation Loss: 0.2487, Accuracy: 92.82%
Epoch 3/200, Validation Loss: 0.1969, Accuracy: 94.38%
Epoch 4/200, Validation Loss: 0.1647, Accuracy: 95.00%
Epoch 5/200, Validation Loss: 0.1381, Accuracy: 95.64%
Epoch 6/200, Validation Loss: 0.1356, Accuracy: 95.87%
Epoch 7/200, Validation Loss: 0.1156, Accuracy: 96.28%
Epoch 8/200, Validation Loss: 0.1016, Accuracy: 96.82%
Epoch 9/200, Validation Loss: 0.1042, Accuracy: 96.65%
Epoch 10/200, Validation Loss: 0.0943, Accuracy: 97.07%
Epoch 11/200, Validation Loss: 0.0892, Accuracy: 97.22%
Epoch 12/200, Validation Loss: 0.1093, Accuracy: 96.69%
Epoch 13/200, Validation Loss: 0.0894, Accuracy: 97.12%
Epoch 14/200, Validation Loss: 0.0837, Accuracy: 97.36%
Epoch 15/200, Validation Loss: 0.0897, Accuracy: 97.26%
Epoch 16/200, Validation Loss: 0.0947, Accuracy: 97.17%
Epoch 17/200, Validation Loss: 0.0833, Accuracy: 97.44%
Epoch 18/200, Validation Loss: 0.0795, Accuracy: 97.50%
E

# Another flavour of Earlystopping

In [8]:
import torch
import copy

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, restore_best_weights=True, path="best_model.pth"):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.path = path

        self.best_model = None
        self.best_loss = None
        self.best_acc = None
        self.counter = 0
        self.early_stop = False

    def save_checkpoint(self, model, val_acc):
        torch.save({'model_state_dict': model.state_dict(),
                    'val_acc': val_acc}, self.path)

    def should_stop(self, model, val_loss, val_acc):
        if self.best_loss is None:
            self._initialize_best_loss(val_loss, val_acc)
        elif self._is_improvement(val_loss):
            self._update_best_loss(model, val_loss, val_acc)
            self.counter = 0
        else:
            self._no_improvement()
            self.counter += 1
            if self.counter >= self.patience:
                self._trigger_early_stopping()
                return True
        return False

    def _initialize_best_loss(self, val_loss, val_acc):
        self.best_loss = val_loss
        self.best_acc = val_acc
        self.best_model = copy.deepcopy(model.state_dict())
        print("Initialized best loss and accuracy.")

    def _is_improvement(self, val_loss):
        improvement = self.best_loss - val_loss >= self.min_delta
        if improvement:
            print("Improved validation loss.")
        return improvement

    def _update_best_loss(self, model, val_loss, val_acc):
        self.best_loss = val_loss
        self.best_acc = val_acc
        self.best_model = copy.deepcopy(model.state_dict())
        print("Updated best loss and accuracy.")

    def _no_improvement(self):
        print("No improvement in validation loss.")

    def _trigger_early_stopping(self):
        print(f"Early stopping triggered after {self.counter} epochs with no loss improvement.")
        if self.restore_best_weights:
            print("Restoring best weights.")
            model.load_state_dict(self.best_model)
            self.save_checkpoint(model, self.best_acc)
            print(f"Accuracy of the best model: {self.best_acc:.2f}%")
        self.early_stop = True


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Define your neural network architecture
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Split the dataset into train and validation sets
train_data, val_data = train_test_split(train_dataset, test_size=0.2, random_state=42)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

# Initialize model, loss function, optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize EarlyStopping object
early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=True, path="best_model.pth")

# Training loop
for epoch in range(20):  # Let's say we're training for 20 epochs
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            val_loss += criterion(output, target).item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    val_loss /= len(val_loader)
    val_acc = 100 * correct / total
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%')

    # Check if early stopping criteria are met
    if early_stopping.should_stop(model, val_loss, val_acc):  # Pass the 'model' argument as well
        print("Early stopping triggered.")
        break


Epoch 1, Validation Loss: 0.2552, Validation Accuracy: 92.40%
Initialized best loss and accuracy.
Epoch 2, Validation Loss: 0.1828, Validation Accuracy: 94.69%
Improved validation loss.
Updated best loss and accuracy.
Epoch 3, Validation Loss: 0.1642, Validation Accuracy: 95.15%
Improved validation loss.
Updated best loss and accuracy.
Epoch 4, Validation Loss: 0.1271, Validation Accuracy: 96.03%
Improved validation loss.
Updated best loss and accuracy.
Epoch 5, Validation Loss: 0.1240, Validation Accuracy: 96.33%
Improved validation loss.
Updated best loss and accuracy.
Epoch 6, Validation Loss: 0.1087, Validation Accuracy: 96.67%
Improved validation loss.
Updated best loss and accuracy.
Epoch 7, Validation Loss: 0.1130, Validation Accuracy: 96.63%
No improvement in validation loss.
Epoch 8, Validation Loss: 0.1080, Validation Accuracy: 96.74%
No improvement in validation loss.
Epoch 9, Validation Loss: 0.1124, Validation Accuracy: 96.70%
No improvement in validation loss.
Epoch 10, V