In [None]:
from google.colab import drive
# Mount Google Drive so the checkpoint / best-model files written further down
# land on persistent storage rather than the ephemeral Colab disk.
drive.mount(mountpoint='/content/drive')

In [None]:
# Persistent artefacts on the mounted Drive folder:
#   BEST_MODEL_PATH - state_dict of the epoch with the lowest test loss so far
#   DRIVE_CHP_PATH  - full resumable checkpoint written after every epoch
BEST_MODEL_PATH, DRIVE_CHP_PATH = (
    '/content/drive/MyDrive/MyModel/model_best.pt',
    '/content/drive/MyDrive/MyModel/model_last_checkpoint.pth',
)

In [1]:
import os

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

class Trainer():
    """Run a resumable train/evaluate loop for one model.

    On construction the trainer either builds a fresh model, AdamW optimizer and
    ReduceLROnPlateau scheduler, or restores all three (plus the epoch to resume
    from and the best test loss so far) from ``checkpoint``.
    ``train_and_evaluate`` then saves the lowest-test-loss weights to
    ``BEST_MODEL_PATH`` and a full resumable checkpoint to ``DRIVE_CHP_PATH``
    after every epoch.

    NOTE(review): ``self.device``, ``self.utils``, ``self.get_model``,
    ``self.train_dataloader``, ``self.test_dataloader`` and
    ``self.training_result`` are used below but are set up in code elided from
    this view. Confirm them before relying on them.
    """

    def __init__ (self, config, dataset, checkpoint=None):
        """Build or restore the model, optimizer and scheduler.

        Args:
            config: hyper-parameter holder; ``config.lr`` and ``config.wd`` are
                read here as AdamW's learning rate and weight decay.
            dataset: training/test data (consumed by elided code).
            checkpoint: dict previously written by ``save_checkpoint`` (keys
                'minloss', 'epoch', 'model', 'optimizer', 'scheduler'), or
                ``None`` to start from scratch.
        """

        # ...
        # ...
        # ...

        if checkpoint is None:
            self.model = self.get_model(self.device)
            self.optimizer = torch.optim.AdamW(self.model.parameters(), config.lr, weight_decay = config.wd)
            # mode='min': lower monitored value is better. The LR is scaled by 0.6
            # (never below 1e-6) once the value passed to scheduler.step() has not
            # improved by a relative 1% for more than 3 consecutive steps.
            self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.6, patience= 3, min_lr=1e-6, threshold=0.01)
            self.epoch, self.minloss = 0, float('inf')

        else:
            # self.epoch is the first epoch still to run (save_checkpoint is
            # given epoch+1), so train_and_evaluate resumes where it stopped.
            self.model, self.optimizer, self.scheduler, self.epoch, self.minloss = self.utils.load_checkpoint(checkpoint)
            print('\rCheckpoint loaded successfully')
            print('-' * 50)

        # ...
        # ...
        # ...


    @staticmethod
    def _atomic_save(obj, path):
        """``torch.save`` *obj* to *path* without ever leaving a half-written file there.

        The data goes to ``<path>.tmp`` first and is then renamed over *path*.
        If the runtime is interrupted mid-write, the previous file at *path*
        stays intact and loadable. File-like targets are written directly,
        because there is nothing to rename.
        """
        if not isinstance(path, (str, os.PathLike)):
            torch.save(obj, path)
            return
        tmp_path = f'{os.fspath(path)}.tmp'
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)


    def save_checkpoint(self, model, optimizer, scheduler, epoch, minloss, save_path):
        """Persist everything needed to resume training to *save_path*.

        The model, optimizer and scheduler objects are pickled whole, not as
        state_dicts. Reading the file back therefore needs
        ``torch.load(..., weights_only=False)`` and the same class definitions.

        Args:
            epoch: the epoch to resume FROM (callers pass ``epoch + 1``).
            minloss: best test loss seen so far.
            save_path: destination file; it is replaced atomically.
        """
        checkpoint = {
            'minloss': minloss,
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer,
            'scheduler': scheduler}
        self._atomic_save(checkpoint, save_path)



    def save_model(self, model, path, epoch):
        """Save *model*'s state_dict to *path* and record *epoch* as the best epoch."""
        self._atomic_save(model.state_dict(), path)
        self.training_result['Best Epoch'] = epoch



    def train (self, model, dataloader, optimizer, epoch, device):
        """Run one training epoch.

        Placeholder: ``train_and_evaluate`` stores the result as ``train_loss``,
        so the real body should return that epoch's loss as a float.
        """
        # ...
        # ...

        return True



    def evaluate (self, model, dataloader, device):
        """Evaluate *model* on *dataloader*.

        Placeholder: ``train_and_evaluate`` compares the result against
        ``self.minloss``, so the real body should return a float loss.
        """
        # ...
        # ...

        return True



    def train_and_evaluate(self, n_epochs):
        """Train from ``self.epoch`` up to (not including) *n_epochs*.

        After each epoch, the weights go to ``BEST_MODEL_PATH`` if the test loss
        beat ``self.minloss``. A resumable checkpoint is then written to
        ``DRIVE_CHP_PATH`` either way.
        """

        loss_list =  []
        for epoch in range(self.epoch, n_epochs):

            # ...
            # ...

            train_loss = self.train(self.model, self.train_dataloader, self.optimizer, epoch, self.device)
            test_loss = self.evaluate(self.model, self.test_dataloader, self.device)

            # Save the best model based on minimum loss
            if test_loss < self.minloss:
                self.minloss = test_loss
                self.save_model(self.model, BEST_MODEL_PATH, epoch)

            # ...
            # ...

            # Store epoch+1: this epoch is finished, so a resumed run starts at the next one.
            self.save_checkpoint(self.model, self.optimizer, self.scheduler, epoch+1, self.minloss, DRIVE_CHP_PATH )

            # ...
            # ...


In [None]:
import torch
import os

def load_checkpoint(checkpoint):
    """Unpack a checkpoint dict with the keys written by ``Trainer.save_checkpoint``.

    Args:
        checkpoint: mapping with keys 'minloss', 'epoch', 'model', 'optimizer'
            and 'scheduler'.

    Returns:
        Tuple ``(model, optimizer, scheduler, epoch, minloss)``.
    """
    minloss, epoch, model, optimizer, scheduler = (
        checkpoint[key] for key in ('minloss', 'epoch', 'model', 'optimizer', 'scheduler')
    )
    return model, optimizer, scheduler, epoch, minloss


if __name__ == '__main__':

    # ...
    # ...

    # Resume from the last Drive checkpoint when one exists.
    # The checkpoint pickles whole model/optimizer/scheduler objects, not just
    # tensors. torch.load's weights_only=True mode (the default since PyTorch 2.6)
    # rejects such files, so it is disabled explicitly here.
    # SECURITY: weights_only=False unpickles arbitrary objects. Only point this
    # at files this notebook wrote itself.
    path = DRIVE_CHP_PATH
    checkpoint = torch.load(path, weights_only=False) if os.path.isfile(path) else None

    # Initialize trainer
    trainer = Trainer(config, {'train_data':train_data, 'test_data':test_data}, checkpoint)

    # Train and evaluate the model
    trainer.train_and_evaluate(50)
