# In[4]:
import os
import time
import yaml
import torch
import pickle
import numpy as np

from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import DataLoader
from torch import nn

# In[None]:
class Trainer:
    """Bundle a CRNN model with its data loaders, training loop and checkpoints.

    The constructor builds the model, optionally loads pretrained weights,
    creates one ``DataLoader`` per dataset split and makes sure the output
    folder exists.

    NOTE(review): ``CRNN`` and ``ImageDataset`` are not defined or imported in
    the visible part of this file; they must be in scope before the class is
    instantiated.
    """

    # Maps a save/load policy to the file name used for the pickled model.
    # The state-dict file is the same name with a "_state_dict" suffix.
    _MODEL_FILES = {'best': 'model', 'last': 'model_last'}

    def __init__(self, device, save_folder,
                 train_dataset, valid_dataset, test1_dataset, test2_dataset, num_imgs=-1,
                 model_params=None, num_workers=8, batch_size=4, seed=34, max_epochs=1, model_pretrain=''):
        """Create the model, the data loaders and the output directory.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).
            save_folder: Directory where checkpoints are written.
            train_dataset, valid_dataset, test1_dataset, test2_dataset:
                Values forwarded to ``ImageDataset(pickle_file=...)``.
            num_imgs: Forwarded to ``ImageDataset(num_imgs=...)``.
                Presumably an image-count limit with -1 meaning "all" --
                TODO confirm against ``ImageDataset``.
            model_params: Configuration passed to ``CRNN``. Required; it only
                has a ``None`` default because it follows a defaulted
                parameter in the positional order.
            num_workers: Worker count for every ``DataLoader``.
            batch_size: Batch size for every ``DataLoader``.
            seed: Seed applied to torch and numpy.
            max_epochs: Training stops once this many epochs have run; a
                value <= 0 disables the limit.
            model_pretrain: Optional path to a state dict to load into the
                freshly built model. Empty string means "no pretraining".

        Raises:
            ValueError: If ``model_params`` is not provided.
        """
        super().__init__()

        if model_params is None:
            raise ValueError("model_params is required to build the CRNN model")

        self.device = torch.device(device)
        self.save_folder = save_folder

        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test1_dataset = test1_dataset
        self.test2_dataset = test2_dataset
        self.num_imgs = num_imgs
        self.num_workers = num_workers

        self.seed = seed
        self.batch_size = batch_size
        self.max_epochs = max_epochs

        self.model_params = model_params

        self.model = CRNN(self.model_params).to(self.device)
        # Truthiness check instead of `is not ''`: identity comparison with a
        # string literal is implementation-dependent.
        if model_pretrain:
            self.model.load_state_dict(torch.load(model_pretrain, map_location=self.device))
            print(f'Successfully loaded model from {model_pretrain}')
        self.loss = nn.CTCLoss()

        self.initialize_training()
        self.prepare_dirs()

    def _make_loader(self, pickle_file, shuffle):
        """Build a ``DataLoader`` over an ``ImageDataset`` for one split."""
        dataset = ImageDataset(pickle_file=pickle_file,
                               meta=False, num_imgs=self.num_imgs)
        return DataLoader(dataset=dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=self.num_workers)

    def initialize_training(self):
        """Reset training state, seed the RNGs and build all data loaders."""
        self.patience = 0
        self.start_time = None
        self.score_best = None
        # The training loop reads and increments this counter, so it has to
        # exist before the first epoch.
        self.epoch = 0

        self.fix_seeds()

        # Only the training split is shuffled.
        self.train_iterator = self._make_loader(self.train_dataset, shuffle=True)
        self.valid_iterator = self._make_loader(self.valid_dataset, shuffle=False)
        self.test1_iterator = self._make_loader(self.test1_dataset, shuffle=False)
        self.test2_iterator = self._make_loader(self.test2_dataset, shuffle=False)

    def prepare_dirs(self):
        """Create ``save_folder`` if it is non-empty and does not exist yet."""
        # An empty folder means "current directory"; os.makedirs('') would raise.
        if self.save_folder:
            os.makedirs(self.save_folder, exist_ok=True)

    def get_parameters(self):
        """Return the constructor arguments needed to rebuild this trainer.

        ``model_pretrain`` is intentionally absent: the dict is consumed by
        ``load``, which restores the weights from the saved model instead.
        """
        params = {
            "device": str(self.device),
            "save_folder": self.save_folder,
            "train_dataset": self.train_dataset,
            "valid_dataset": self.valid_dataset,
            "test1_dataset": self.test1_dataset,
            "test2_dataset": self.test2_dataset,
            "num_imgs": self.num_imgs,
            "num_workers": self.num_workers,
            "seed": self.seed,
            "batch_size": self.batch_size,
            "model_params": self.model_params,
            "max_epochs": self.max_epochs,
        }
        return params

    def fix_seeds(self):
        """Seed torch and numpy and configure cuDNN."""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        # NOTE(review): benchmark mode is on and deterministic mode is off, so
        # per the PyTorch reproducibility notes GPU runs may still differ
        # between executions despite the fixed seeds -- confirm this is wanted.
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

    def train_on_batches(self):
        """Run the epoch loop until ``max_epochs`` is reached.

        Each epoch iterates over ``train_iterator`` with a progress bar, then
        increments ``self.epoch`` and calls ``self.validate()``. With
        ``max_epochs <= 0`` the loop only ends on an exception such as
        ``KeyboardInterrupt`` (handled by ``train``).

        NOTE(review): the per-batch training step is still a placeholder. It
        must assign ``current_loss``; ``self.scheduler_obj`` and
        ``self.validate`` are not defined in the visible part of this class
        and must be provided before this method can run.
        """
        self.model.train()
        while True:
            with tqdm(total=len(self.train_iterator)) as bar_train:
                for x, y_true in self.train_iterator:
                    '''TRAINING CODE HERE'''

                    '''TRAINING CODE HERE'''
                    bar_train.set_description(
                        f"Epoch: {self.epoch:3}. Current train loss: {current_loss:8.7}")
                    bar_train.update(1)
                    # The scheduler is stepped once per batch, not per epoch.
                    self.scheduler_obj.step()

            self.epoch += 1
            self.validate()
            if 0 < self.max_epochs <= self.epoch:
                break

    def train(self):
        """Train the model and always save a 'last' checkpoint afterwards."""
        try:
            self.train_on_batches()
        except KeyboardInterrupt:
            print("Stopped training")
        finally:
            # Runs on normal completion, on Ctrl-C and on any other error.
            self.save(save_policy='last')

    def save(self, save_policy='best'):
        """Write the model and the trainer parameters to ``save_folder``.

        Args:
            save_policy: ``'best'`` writes ``model`` / ``model_state_dict``;
                ``'last'`` writes ``model_last`` / ``model_last_state_dict``.

        Raises:
            ValueError: If ``save_policy`` is neither ``'best'`` nor ``'last'``.
        """
        if save_policy not in self._MODEL_FILES:
            raise ValueError(
                f"Unknown save_policy {save_policy!r}; expected 'best' or 'last'")

        print(f"Saving trainer to {self.save_folder}.")
        self.prepare_dirs()

        model_file = self._MODEL_FILES[save_policy]
        torch.save(self.model.state_dict(),
                   os.path.join(self.save_folder, f"{model_file}_state_dict"))
        torch.save(self.model, os.path.join(self.save_folder, model_file))

        torch.save({
            "parameters": self.get_parameters()
        }, os.path.join(self.save_folder, "trainer"))
        print("Trainer is saved.")

    @classmethod
    def load(cls, load_folder, device="cpu", load_policy='best'):
        """Rebuild a trainer from a folder previously written by ``save``.

        Args:
            load_folder: Directory containing the ``trainer`` file and the
                pickled model.
            device: Device to build the trainer on; overrides the saved one.
            load_policy: ``'best'`` loads ``model``, ``'last'`` loads
                ``model_last``.

        Returns:
            A new trainer whose ``model`` is the unpickled saved model.

        Raises:
            ValueError: If ``load_policy`` is neither ``'best'`` nor ``'last'``.
        """
        if load_policy not in cls._MODEL_FILES:
            raise ValueError(
                f"Unknown load_policy {load_policy!r}; expected 'best' or 'last'")

        checkpoint = torch.load(os.path.join(load_folder, "trainer"), map_location=device)
        parameters = checkpoint["parameters"]
        # The saved device is replaced by the one requested by the caller.
        parameters.pop("device", None)
        trainer = cls(device=device, **parameters)
        # map_location lets a model saved on GPU be restored on a CPU-only host.
        # NOTE(review): this unpickles a full module; recent PyTorch releases
        # default torch.load to weights_only=True, which rejects such files --
        # confirm against the installed version.
        trainer.model = torch.load(
            os.path.join(load_folder, cls._MODEL_FILES[load_policy]),
            map_location=device)
        return trainer
