In [None]:
def train(model, optimizer, criterion, n_epochs, device, train_loader, valid_loader, lr_schedule=None,
          prehistory=None, checkpoint_file_path=None, checkpoint_note=""):
    """Train `model` for `n_epochs` epochs, validating after each one and tracking the best state.

    After every epoch the mean training/validation loss and accuracy are appended to the
    history. Whenever the validation loss improves on the best seen so far, a deep-copied
    snapshot of the history, model, optimizer and scheduler state is kept (and optionally
    written to disk). From epoch 1 onwards a two-panel loss/accuracy figure is redrawn after
    each epoch. A `KeyboardInterrupt` stops training early and still returns the results.

    Relies on names that must already be in scope where this function is defined:
    `torch`, `optim`, `lr_scheduler`, `np`, `tqdm`, `plt`, `AutoLocator`, `deepcopy`,
    `Decimal`, and the plot settings `fig_bg_color`, `plot_bg_color` and `fontsize`.

    Args:
        model: Called as `model(inputs)`; its output is reduced with `argmax(dim=1)` and
            compared with the labels, so it is used here as a classifier.
        optimizer: An `optim.Optimizer` instance.
        criterion: Called as `criterion(outputs, labels)`. The per-batch loss is weighted
            by the batch size, which presumably assumes a mean-reduced loss -- confirm
            against the criterion actually used.
        n_epochs: Number of epochs to run. Values below 1 return immediately.
        device: Passed to `.to(...)` for the model and for every batch.
        train_loader: Iterable of `(inputs, labels)` batches; must support `len()` when a
            scheduler is used.
        valid_loader: Iterable of `(inputs, labels)` batches.
        lr_schedule: Optional learning-rate scheduler. It is stepped before every training
            batch with a fractional epoch (`epoch + batch_index / len(train_loader)`).
        prehistory: Optional history dict returned by an earlier call. It is extended in
            place and the epoch numbering continues from its last epoch.
        checkpoint_file_path: If given, the best snapshot is saved there with `torch.save`
            each time the validation loss improves.
        checkpoint_note: Free-form value stored under the `'note'` key of the snapshot.

    Returns:
        A tuple `(history, best_dict)`. `history` has the keys `'epoch'`, `'loss'` and
        `'acc'` (the latter two each holding `'train'` and `'valid'` lists). `best_dict`
        has the keys `'history'`, `'model_dict'`, `'optimizer_dict'`, `'lr_dict'` and
        `'note'`.

    Raises:
        TypeError: If `optimizer`, `n_epochs` or `lr_schedule` has an unexpected type.
    """
    if not isinstance(optimizer, optim.Optimizer):
        raise TypeError('expected an optimizer for `optimizer`, but {} was given'.format(type(optimizer)))
    if not isinstance(n_epochs, int):
        raise TypeError('expected an int for `n_epochs`, but {} was given'.format(type(n_epochs)))
    if lr_schedule:
        # Newer PyTorch releases derive the built-in schedulers from the public `LRScheduler`
        # and keep `_LRScheduler` only as a compatibility subclass, so checking against
        # `_LRScheduler` alone would reject them. Prefer the public base when it exists.
        scheduler_base = getattr(lr_scheduler, 'LRScheduler', None) or lr_scheduler._LRScheduler
        # ReduceLROnPlateau.step() expects a metric, but this loop passes an epoch number.
        is_plateau = isinstance(lr_schedule, lr_scheduler.ReduceLROnPlateau)
        if is_plateau or not isinstance(lr_schedule, scheduler_base):
            raise TypeError('expected a _LRScheduler for `lr_schedule`, but {} was given'.format(type(lr_schedule)))

    if prehistory and prehistory['epoch']:
        # Resume: continue numbering and seed the displayed stats from the last recorded epoch.
        history = prehistory
        best_valid_loss = min(history['loss']['valid'])
        train_loss = history['loss']['train'][-1]
        train_acc = history['acc']['train'][-1]
        valid_loss = history['loss']['valid'][-1]
        valid_acc = history['acc']['valid'][-1]
        i_epoch = history['epoch'][-1] + 1
    else:
        # Fresh start. A supplied-but-empty prehistory is still extended in place.
        if prehistory:
            history = prehistory
        else:
            history = {'epoch': [], 'loss': {'train': [], 'valid': []}, 'acc': {'train': [], 'valid': []}}
        best_valid_loss = np.inf
        train_loss = np.inf
        train_acc = 0.
        valid_loss = np.inf
        valid_acc = 0.
        i_epoch = 0

    def set_matplotlib_backend(backend):
        """Run the IPython `%matplotlib <backend>` magic; do nothing outside IPython."""
        try:
            shell = get_ipython()
        except NameError:
            return
        if shell is not None:
            # This is the call IPython itself substitutes for a `%matplotlib ...` line.
            shell.run_line_magic('matplotlib', backend)

    def snapshot():
        """Return an independent copy of the current history and model/optimizer/scheduler state."""
        return deepcopy({'history': history,
                         'model_dict': model.state_dict(),
                         'optimizer_dict': optimizer.state_dict(),
                         'lr_dict': lr_schedule.state_dict() if lr_schedule else None,
                         'note': checkpoint_note})

    def update_progress_stats(update_epoch, update_train, update_valid):
        """Refresh the postfix text of the epoch / train-batch / valid-batch progress bars."""
        if update_epoch:
            group = optimizer.param_groups[0]
            if 'momentum' in group:
                momentum = group['momentum']
            elif 'betas' in group:
                momentum = group['betas'][0]
            else:
                momentum = None
            # Optimizers with neither setting used to crash on `{:.4f}`.format(None).
            momentum_text = "n/a" if momentum is None else "{:.4f}".format(momentum)

            best_history = best_dict['history']
            if best_history['epoch']:
                # The snapshot is taken right after the best epoch is appended, so its last
                # entries describe that best epoch.
                best_epoch = best_history['epoch'][-1]
                best_loss = best_history['loss']['valid'][-1]
                best_acc = best_history['acc']['valid'][-1]
            else:
                best_epoch, best_loss, best_acc = -1, np.inf, 0.
            current_epoch = history['epoch'][-1] + 1 if history['epoch'] else 0

            epoch_iterator.set_postfix_str("current_epoch={}, "
                                           "train(loss={:.4f}, acc={:.3f}), "
                                           "valid(loss={:.4f}, acc={:.3f}), "
                                           "best_valid(epoch={}, loss={:.4f}, acc={:.3f}), "
                                           "lr={:.4e}, momentum={}"
                                           .format(current_epoch,
                                                   train_loss, train_acc,
                                                   valid_loss, valid_acc,
                                                   best_epoch, best_loss, best_acc,
                                                   Decimal(group['lr']),
                                                   momentum_text),
                                           refresh=True)
        if update_train:
            train_iterator.set_postfix_str("loss={}, acc={}".format(loss.item(), corrects / inputs.size(0)),
                                           refresh=True)
        if update_valid:
            valid_iterator.set_postfix_str("loss={}, acc={}".format(loss.item(), corrects / inputs.size(0)),
                                           refresh=True)

    def open_live_plot():
        """Create the two-panel (loss, accuracy) figure with empty train/valid lines."""
        set_matplotlib_backend('notebook')
        figure, panels = plt.subplots(nrows=2, ncols=1, figsize=(13, 24), facecolor=fig_bg_color)
        layout = (("Training/Validation Loss", "Loss", "loss"),
                  ("Training/Validation Accuracy", "Accuracy", "acc"))
        for panel, (title, ylabel, tag) in zip(panels, layout):
            panel.set_facecolor(plot_bg_color)
            panel.grid(True)
            panel.set_title(title, fontsize=fontsize)
            panel.set_xlabel("Epoch", fontsize=fontsize)
            panel.set_ylabel(ylabel, fontsize=fontsize)
            panel.plot([], [], color='blue', label='train ' + tag)
            panel.plot([], [], color='orange', label='valid ' + tag)
            panel.legend()
        figure.canvas.draw()
        return figure, panels

    def refresh_axis(panel, metric, epoch, valid_floor):
        """Redraw one panel's curves, limits and ticks for `history[metric]`.

        `valid_floor` is the validation value that gets its own tick alongside the other
        extremes (the best validation loss, or the minimum validation accuracy).
        """
        epochs = history['epoch']
        train_series = history[metric]['train']
        valid_series = history[metric]['valid']

        panel.lines[0].set_xdata(epochs)
        panel.lines[0].set_ydata(train_series)
        panel.lines[1].set_xdata(epochs)
        panel.lines[1].set_ydata(valid_series)

        panel.set_xlim(-0.05 * epoch, 1.05 * epoch)
        combined = train_series + valid_series
        lowest, highest = min(combined), max(combined)
        spread = highest - lowest
        if spread > 0:
            # Pad the y-range by 5% on each side; a flat curve keeps the current limits.
            panel.set_ylim(lowest - 0.05 * spread, highest + 0.05 * spread)

        panel.xaxis.set_major_locator(AutoLocator())
        panel.yaxis.set_major_locator(AutoLocator())

        # Keep the automatic ticks that fall inside the view, then add a tick at every
        # extreme and at the latest epoch so those points can be read off the axes.
        x_low, x_high = panel.get_xlim()
        xticks = [tick for tick in panel.get_xticks() if x_low <= tick <= x_high]
        extremes = ((train_series, max(train_series)),
                    (train_series, min(train_series)),
                    (valid_series, max(valid_series)),
                    (valid_series, valid_floor))
        for series, value in extremes:
            # Map the list position to the recorded epoch number: the curves are plotted
            # against `history['epoch']`, which need not start at 0.
            tick = epochs[series.index(value)]
            if tick not in xticks:
                xticks.append(tick)
        if epoch not in xticks:
            xticks.append(epoch)
        panel.set_xticks(xticks)

        y_low, y_high = panel.get_ylim()
        yticks = [tick for tick in panel.get_yticks() if y_low <= tick <= y_high]
        for value in (max(train_series), min(train_series), max(valid_series), valid_floor,
                      train_series[-1], valid_series[-1]):
            if value not in yticks:
                yticks.append(value)
        panel.set_yticks(yticks)

    best_dict = snapshot()

    if n_epochs < 1:
        return history, best_dict

    model.to(device)

    fig = None
    axes = None

    if lr_schedule:
        last_epoch = int(lr_schedule.last_epoch)

    try:
        epoch_iterator = tqdm(iterable=range(i_epoch, i_epoch + n_epochs), desc="Train Epochs")
        for epoch in epoch_iterator:
            if lr_schedule:
                last_epoch += 1

            update_progress_stats(True, False, False)

            # ---- training pass ----
            model.train()
            running_loss = 0
            n_corrects = 0
            n_instances = 0
            train_iterator = tqdm(iterable=train_loader, desc="Train Iterations", leave=False)
            for i, (inputs, labels) in enumerate(train_iterator):
                if lr_schedule:
                    # Step per batch with a fractional epoch.
                    lr_schedule.step(last_epoch + (i / len(train_loader)))

                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += (loss * inputs.size(0)).item()
                corrects = outputs.argmax(dim=1).eq(labels).sum().item()
                n_corrects += corrects
                n_instances += inputs.size(0)

                update_progress_stats(True, True, False)

            if lr_schedule:
                # Snap the scheduler's counter back to the whole epoch number.
                lr_schedule.last_epoch = last_epoch
            train_loss = running_loss / n_instances
            train_acc = n_corrects / n_instances

            update_progress_stats(True, False, False)

            # ---- validation pass ----
            model.eval()
            running_loss = 0
            n_corrects = 0
            n_instances = 0
            valid_iterator = tqdm(iterable=valid_loader, desc="Valid Iterations", leave=False)
            with torch.no_grad():
                for inputs, labels in valid_iterator:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    running_loss += (loss * inputs.size(0)).item()
                    corrects = outputs.argmax(dim=1).eq(labels).sum().item()
                    n_corrects += corrects
                    n_instances += inputs.size(0)
                    update_progress_stats(True, False, True)
            valid_loss = running_loss / n_instances
            valid_acc = n_corrects / n_instances

            update_progress_stats(True, False, False)

            history['epoch'].append(epoch)
            history['loss']['train'].append(train_loss)
            history['acc']['train'].append(train_acc)
            history['loss']['valid'].append(valid_loss)
            history['acc']['valid'].append(valid_acc)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_dict = snapshot()
                if checkpoint_file_path:
                    torch.save(best_dict, checkpoint_file_path)

            if epoch >= 1:
                if axes is None:
                    fig, axes = open_live_plot()

                refresh_axis(axes[0], 'loss', epoch, best_valid_loss)
                refresh_axis(axes[1], 'acc', epoch, min(history['acc']['valid']))

                fig.canvas.draw()
                plt.pause(0.0001)

    except KeyboardInterrupt:
        # Manual stop: reset the scheduler to the last whole epoch and return what we have.
        if lr_schedule:
            lr_schedule.step(int(last_epoch))
    finally:
        # Restore the inline backend even if training fails with another exception.
        set_matplotlib_backend('inline')
    return history, best_dict
    