In [2]:
# Libraries related to the neural net
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import save

# Utility libraries
from PIL import Image
import os
import pandas as pd
import time
import numpy as np
import datetime

In [3]:
# Class that manages the neural net (ResNet-18 with a custom classification head)
class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone whose final fully connected layer is resized.

    The backbone is initialised from torchvision's ImageNet-1k (V1) weights,
    and its ``fc`` head is replaced by a freshly initialised ``nn.Linear``
    that outputs ``num_classes`` values.

    Args:
        num_classes: Number of output units of the replacement head.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # The new head must accept the same feature width the original fc consumed.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-18 and return the head's raw output."""
        return self.model(x)

In [4]:
class EarlyStop:
    """Early-stopping helper driven by the per-epoch test loss.

    Call :meth:`checking_conditions` once per epoch:

    * If the test loss improves on the best value seen so far, the model's
      ``state_dict`` is checkpointed to disk and the patience counter resets.
    * Otherwise the counter is incremented.

    ``stop_training`` becomes ``True`` in either of two cases:

    * the counter reaches ``patience``;
    * the test loss drops below ``threshold_loss``.

    Args:
        patience: Consecutive non-improving epochs tolerated before
            ``stop_training`` is set.
        threshold_loss: Test loss below which training is stopped.
        verbose: Stored on the instance but not consulted; the progress
            messages are always printed.
        checkpoint_path: File the best ``state_dict`` is written to. If
            ``None``, ``<cwd>/models/temp_model.pth`` is used, resolved at
            the time of each save/restore.

    Attributes:
        best_loss: Lowest test loss observed so far (starts at ``np.inf``).
        best_model: Path of the checkpoint holding the best weights, or
            ``None`` until the first improvement.
        counter: Consecutive epochs without improvement.
        stop_training: ``True`` once a stopping condition has been met.
    """

    def __init__(self, patience=2, threshold_loss=0.001, verbose=False, checkpoint_path=None):
        self.patience = patience
        self.threshold_loss = threshold_loss
        self.verbose = verbose  # NOTE(review): unused; messages below are unconditional.
        self.checkpoint_path = checkpoint_path

        self.best_loss = np.inf
        self.best_model = None
        self.counter = 0
        self.stop_training = False

    def _resolve_checkpoint_path(self):
        """Return the configured checkpoint path, or ``<cwd>/models/temp_model.pth``."""
        if self.checkpoint_path is not None:
            return self.checkpoint_path
        # Resolved lazily so the current working directory at call time is used.
        return os.path.join(os.getcwd(), "models", "temp_model.pth")

    def checking_conditions(self, epoch_result, model):
        """Update the early-stopping state with one epoch's results.

        Args:
            epoch_result: Mapping containing the key ``'Test Loss'``.
            model: Module whose ``state_dict`` is checkpointed on improvement.
        """
        test_loss = epoch_result['Test Loss']

        # Check if the loss is lower than the best loss.
        if self.best_loss > test_loss:
            self.best_loss = test_loss
            self.counter = 0
            # Keep the checkpoint location so the best weights can be found later.
            self.best_model = self.save_model_temporaly(model)
            print("EARLY STOP: The model is learning correctly. (Best Loss: {}).".format(self.best_loss))
        else:
            self.counter += 1
            print("EARLY STOP: The model is not learning correctly (Counter: {}).".format(self.counter))

            # Stop once the number of non-improving epochs reaches the patience.
            if self.counter >= self.patience:
                self.stop_training = True

        # Evaluated after the improvement branch. If this loss is also a new best,
        # its weights have already been checkpointed above. It stops training
        # even when the loss is not a new best.
        if test_loss < self.threshold_loss:
            self.stop_training = True

    def save_model_temporaly(self, model):
        """Write ``model.state_dict()`` to the checkpoint path and return that path.

        The parent directory is created if it is missing. An existing
        checkpoint at the same path is overwritten.
        """
        path = self._resolve_checkpoint_path()
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(model.state_dict(), path)
        return path

    def restore_model(self, model=None):
        """Load the checkpointed weights into a model and return that model.

        Args:
            model: Module to load the weights into. If ``None``, a new
                ``ResNet18Classifier()`` with its default ``num_classes`` is built.

        Returns:
            The model with the checkpointed ``state_dict`` loaded.
        """
        if model is None:
            model = ResNet18Classifier()
        # Prefer the path this instance actually saved to; fall back to the configured one.
        path = self.best_model if self.best_model is not None else self._resolve_checkpoint_path()
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        # load_state_dict then copies the tensors into the model's own parameters.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

In [5]:
# Instantiate the classifier and the early-stopping helper with their default settings spelled out.
model = ResNet18Classifier(num_classes=4)
earlyStop = EarlyStop(patience=2, threshold_loss=0.001, verbose=False)