In [None]:
from itertools import cycle, islice
import os
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
import json
from torch.nn.utils.rnn import pad_sequence
from typing import List, Union, Dict
from collections import defaultdict

Create the early-stopping class and the neural-network (NN) class.

In [None]:
class EarlyStopping:
    """Stop training once validation accuracy stops improving.

    Adapted from an external early-stopping helper that monitored the minimum
    validation loss (the source link is missing in the original — TODO add
    attribution). This version tracks the *maximum* validation accuracy.

    Call the instance once per evaluation with the latest validation accuracy
    and the model. If the accuracy beats the best value seen so far by more than
    ``delta``, three things happen: the model's ``state_dict`` is saved to
    ``path``, the accuracy is recorded in ``val_acc_max``, and the patience
    counter is reset. Otherwise the counter is incremented. ``early_stop``
    becomes ``True`` once the counter reaches ``patience``.

    Attributes:
        counter: number of consecutive calls without improvement.
        best_score: the best accuracy so far, stored *negated* (``None``
            before the first call).
        early_stop: set to ``True`` once ``counter >= patience``.
        val_acc_max: accuracy of the most recently saved checkpoint
            (``-inf`` before the first call).
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): how many calls without improvement to tolerate
                before setting ``early_stop``. Default: 7
            verbose (bool): if True, report each improvement through
                ``trace_func``. Default: False
            delta (float): minimum increase over the best accuracy so far
                that counts as an improvement (expected to be >= 0). Default: 0
            path (str): file the model checkpoint is written to.
                Default: 'checkpoint.pt'
            trace_func (callable): function used to emit verbose messages.
                Default: print
        """
        self.patience = patience
        self.counter = 0

        self.path = path
        self.trace_func = trace_func
        self.verbose = verbose

        self.best_score = None
        self.delta = delta
        self.early_stop = False
        # We are tracking a maximum, so start below any real accuracy.
        # (``np.Inf`` no longer exists in NumPy 2.0; ``np.inf`` works everywhere.)
        self.val_acc_max = -np.inf

    def __call__(self, val_acc, model):
        """Record one validation accuracy and checkpoint ``model`` if it improved."""
        # Negate so that a higher accuracy gives a lower score; best_score keeps
        # this negated representation.
        score = -val_acc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
        # No improvement means val_acc <= best_acc + delta. In negated terms
        # that is score >= best_score - delta. Using "+ delta" here would
        # wrongly accept small *drops* in accuracy as improvements.
        elif score >= self.best_score - self.delta:
            self.counter += 1

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        """Save ``model.state_dict()`` to ``self.path`` and record ``val_acc`` as the new maximum."""
        if self.verbose:
            self.trace_func(f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')

        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc