In [84]:
#export
import os

from fastai.callbacks import EarlyStoppingCallback
from fastai.vision import Learner, partial, accuracy, nn
from fastai.utils.mod_display import progress_disabled_ctx

In [83]:
import os

from fastai.callbacks import EarlyStoppingCallback
from fastai.vision import Learner, partial, accuracy, untar_data, URLs, ImageList, nn
from fastai.utils.mod_display import progress_disabled_ctx
from include.anchored_graph import AnchorGraph
from include.layer import Add

In [54]:
#export
class ValueHolder():
    def __init__(self):
        self._value = 0
    
    def update_value(self, value):
        if self._value < value: self._value = value
    
    def reset(self):
        self._value = 0
    
    @property
    def value(self):
        return self._value

In [66]:
#export
class ValueTrackingCallback(EarlyStoppingCallback):
    
    def __init__(self, learn:Learner, value_holder:ValueHolder, monitor:str='accuracy', 
                 mode:str='auto', min_delta:int=0, patience:int=0):
        super().__init__(learn, monitor, mode, min_delta, patience)
        self.value_holder = value_holder
    
    def on_train_begin(self, **kwargs):
        self.value_holder.reset()
        super().on_train_begin(**kwargs)
    
    def on_train_end(self, **kwargs):
        super().on_train_end(**kwargs)
        self.value_holder.update_value(self.best)

In [73]:
#export
class Trainer():
    def __init__(self, path, data, loss_func=None, metrics=None, monitor='accuracy'):
        self.data = data
        self.loss_func = nn.CrossEntropyLoss() if loss_func == None else loss_func
        self.metrics = [accuracy] if metrics==None else metrics
        self.monitor = monitor
        
        self.accuracy = ValueHolder()
        self._model_num = 0
        
        if not os.path.exists(path):
            os.mkdir(path)
        self.path = os.path.join(os.getcwd(), path)
        
    def train(self, graph, max_epoch=100, min_delta=0, patience=0):
        model_num = self._model_num
        self._model_num = self._model_num + 1
        learn = Learner(self.data, graph.generate_model(), loss_func=self.loss_func, metrics=self.metrics,
                        callback_fns=[partial(ValueTrackingCallback,
                                              value_holder=self.accuracy, 
                                              monitor=self.monitor, 
                                              min_delta=min_delta, 
                                              patience=patience)])
        progress_disabled_ctx(learn)
        learn.fit(max_epoch)
        print(f'Saving model {model_num}...', end='')
        graph.save(os.path.join(self.path, str(model_num)))
        print(' Done!')
        print(f'Model number: {model_num}\nBest accuracy: {self.accuracy.value}')
        return model_num, self.accuracy.value.item()

In [85]:
!python nb2py.py model_trainer.ipynb

Converted model_trainer.ipynb to exp/nb_model_trainer.py


# Test

In [75]:
gr = AnchorGraph((3, 32, 32), (10,))
gr.deeper_net(Add)
gr.add_connection(gr.anchor[1], gr.anchor[2], layer_features=[64, 64, 128])
gr.visualize('aabc', './transform')

In [76]:
data_path = untar_data(URLs.CIFAR)
data = ImageList.from_folder(data_path).split_by_folder(train="train", valid="test").label_from_folder().databunch(bs=128)

In [77]:
tr = Trainer(os.path.join(os.getcwd(), 'trainer_test'), data, nn.CrossEntropyLoss())

In [79]:
tr.train(gr, patience=0)

epoch     train_loss  valid_loss  accuracy  time    
0         1.350756    1.268866    0.551800  00:07     
1         1.329440    1.679515    0.433400  00:06     
Epoch 1: early stopping
Saving model 1... Done!
Model number: 1
Best accuracy: 0.551800012588501


(1, 0.551800012588501)