# Progress Bar

In [1]:
# Automatically reload edited modules (e.g. the helpers imported from the
# `scripts` directory below) before each cell executes.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
#export
import sys
import time
from functools import partial

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))

from fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import format_time
from early_stopping import *

In [3]:
class StatsLogging(Callback):
    '''Stats logging callback that also records how long each epoch took.

    At the start of ``fit`` it emits a header row through ``self.logger``; at
    the end of every epoch it emits one row with the epoch number, the averaged
    training loss and metrics, the averaged validation loss and metrics, and
    the elapsed time of the epoch formatted by ``format_time``.

    Args:
        metrics: metric functions averaged alongside the loss. Each one's
            ``__name__`` becomes a column header. Defaults to
            ``[compute_accuracy]``.
    '''
    def __init__(self, metrics=None):
        # Build the default per instance instead of sharing one list object
        # across every StatsLogging created with the default argument.
        if metrics is None:
            metrics = [compute_accuracy]
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def before_fit(self):
        '''Log the table header: epoch, train/valid loss and metrics, time.'''
        metric_names = ['loss'] + [m.__name__ for m in self.train_stats.metrics]
        names = (['epoch']
                 + [f'train_{n}' for n in metric_names]
                 + [f'valid_{n}' for n in metric_names]
                 + ['time'])
        self.logger(names)

    def before_epoch(self):
        '''Clear the running averages and start the epoch timer.'''
        self.train_stats.reset()
        self.valid_stats.reset()
        # perf_counter is monotonic, so the measured duration is not skewed
        # by system clock adjustments the way time.time() can be.
        self.start_time = time.perf_counter()

    def after_loss(self):
        '''Accumulate the current batch into the train or valid averages,
        chosen by whether the model is in training mode.'''
        stats = self.train_stats if self.model.training else self.valid_stats
        stats.accumulate(self.learner)

    def after_epoch(self):
        '''Log one row of averaged stats followed by the epoch's elapsed time.'''
        stats = [str(self.epoch)]
        for o in (self.train_stats, self.valid_stats):
            stats += [f'{v:.6f}' for v in o.avg_stats]
        stats.append(format_time(time.perf_counter() - self.start_time))
        self.logger(stats)

In [4]:
# export 
class ProgressViewer(Callback):
    '''Render training progress with fastprogress bars.

    A master bar iterates over ``range(num_epochs)`` for the whole ``fit`` call.
    A fresh child progress bar is attached for the training loader at the start
    of each epoch and again for the validation loader before validation. The
    learner's ``logger`` is replaced so that logged rows are written into the
    master bar's table output.
    '''

    # ---- fit-level hooks -------------------------------------------------
    def before_fit(self):
        epochs_bar = master_bar(range(self.num_epochs))
        self.mbar = epochs_bar
        epochs_bar.on_iter_begin()
        # Redirect logging into the master bar's table.
        table_writer = partial(epochs_bar.write, table=True)
        self.learner.logger = table_writer

    def after_fit(self):
        self.mbar.on_iter_end()

    # ---- per-phase hooks -------------------------------------------------
    def before_epoch(self):
        self.set_pb(self.data_bunch.train_dl)

    def before_valid(self):
        self.set_pb(self.data_bunch.valid_dl)

    def after_batch(self):
        # NOTE(review): assumes `iters_count` is the batch index within the
        # loader currently being iterated; if it counts across epochs/phases the
        # child bar would overshoot its total -- confirm against Learner.
        self.pb.update(self.iters_count)

    # ---- helpers ---------------------------------------------------------
    def set_pb(self, data_loader):
        '''Attach a new child progress bar over `data_loader` and move the
        master bar to the current epoch.'''
        child_bar = progress_bar(data_loader, parent=self.mbar)
        self.pb = child_bar
        self.mbar.update(self.epoch)

# Tests

In [5]:
# Build the MNIST data bunch, a small linear model with its optimizer and
# loss, and the progress-bar / stats-logging callbacks under test.
mnist_splits = get_mnist_data()
data_bunch = get_data_bunch(*mnist_splits, batch_size=64)
model = get_lin_model(data_bunch)
model_params = list(model.parameters())
optimizer = DynamicOpt(model_params, learning_rate=0.1)
loss_fn = CrossEntropy()
callbacks = [ProgressViewer(), StatsLogging()]

In [6]:
# Assemble the learner from the pieces above and show its summary.
learner = Learner(
    data_bunch,
    model,
    loss_fn,
    optimizer,
    callbacks,
)
print(learner)

(DataBunch) 
    (DataLoader) 
        (Dataset) x: (50000, 784), y: (50000,)
        (Sampler) total: 50000, batch_size: 64, shuffle: True
    (DataLoader) 
        (Dataset) x: (10000, 784), y: (10000,)
        (Sampler) total: 10000, batch_size: 128, shuffle: False
(Model)
    Linear(784, 50)
    ReLU()
    Linear(50, 10)
(CrossEntropy)
(DynamicOpt) hyper_params: ['learning_rate']
(Callbacks) ['TrainEval', 'ProgressViewer', 'StatsLogging']


In [None]:
# Train for a handful of epochs to exercise the progress bar and stats table.
n_epochs = 5
learner.fit(n_epochs)