In [None]:
#default_exp fastai_callback

### Build our own callback

In [None]:
#export
# build our own callback
from time import sleep
def slow_calculation(cb=None):
    """Sum the squares of 0..4, pausing 1 s per step and reporting progress.

    Args:
        cb: optional callable invoked as ``cb(i)`` after each step ``i``.
            ``None`` (the default) means no progress reporting.

    Returns:
        The sum ``0*0 + 1*1 + ... + 4*4`` (i.e. 30).
    """
    res = 0
    for i in range(5):
        res += i * i
        sleep(1)  # simulate an expensive step
        # Compare against None explicitly: a callback object that happens to be
        # falsy (e.g. defines __len__ or __bool__) must still be called.
        if cb is not None: cb(i)
    return res

In [None]:
def show_progress(epoch):
    """Progress callback for `slow_calculation`: announce that `epoch` has finished."""
    message = f"Awesome, We've finished epoch {epoch}!"
    print(message)

slow_calculation(cb=show_progress)

Awesome, We've finished epoch 0!
Awesome, We've finished epoch 1!
Awesome, We've finished epoch 2!
Awesome, We've finished epoch 3!
Awesome, We've finished epoch 4!


30

In [None]:
# This is called a closure (you'll see it a lot, especially if you're a JavaScript programmer).
def make_show_progress(exclamation):
    """Build a progress callback that remembers `exclamation` from the enclosing scope.

    The returned function takes an epoch number and prints a message. Note that the
    template appends "!," after `exclamation`, so passing "Nice!" prints "Nice!!,".
    """
    def _report(epoch):
        print(f"{exclamation}!, We've finished epoch {epoch}!")
    return _report

slow_calculation(cb=make_show_progress("Nice!"))

Nice!!, We've finished epoch 0!
Nice!!, We've finished epoch 1!
Nice!!, We've finished epoch 2!
Nice!!, We've finished epoch 3!
Nice!!, We've finished epoch 4!


30

In [None]:
#export
def show_progress(exclamation, epoch):
    """Print a progress message for `epoch`, prefixed by `exclamation`.

    Meant to have `exclamation` pre-bound (e.g. with functools.partial) so the
    result can be used as a one-argument callback.
    """
    line = f"{exclamation}, We've finished epoch {epoch}!"
    print(line)
# functools.partial: the standard-library way to pre-bind arguments instead of writing a closure by hand
from functools import partial
# NOTE(review): this call sits in an `#export` cell, so an nbdev export would
# presumably copy it into the module and run it (five 1 s sleeps plus prints)
# on import -- consider moving it to a non-exported cell.
slow_calculation(cb=partial(show_progress, "Ok I guess"))

Ok I guess, We've finished epoch 0!
Ok I guess, We've finished epoch 1!
Ok I guess, We've finished epoch 2!
Ok I guess, We've finished epoch 3!
Ok I guess, We've finished epoch 4!


30

### Callbacks as callable classes

In [None]:
class ProgressShowingCallback:
    """Callable progress callback; the exclamation is stored on the instance."""

    def __init__(self, exclamation="Awesome"):
        self.exclamation = exclamation

    def __call__(self, epoch):
        text = f"{self.exclamation}! We've just finished epoch {epoch}!"
        print(text)

In [None]:
cb = ProgressShowingCallback(exclamation="Just super")

In [None]:
slow_calculation(cb=cb)

Just super! We've just finished epoch 0!
Just super! We've just finished epoch 1!
Just super! We've just finished epoch 2!
Just super! We've just finished epoch 3!
Just super! We've just finished epoch 4!


30

### Multiple callback arguments: `*args` and `**kwargs`

In [None]:
def f(*args, **kwargs):
    """Echo the extra positional args (a tuple) and keyword args (a dict) received."""
    print(f"args {args}; kwargs {kwargs}")

In [None]:
# positional extras land in `args`, named extras in `kwargs`
f(3, "a", thing1="hello")

args (3, 'a'); kwargs {'thing1': 'hello'}


In [None]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, notifying `cb` before and after each step.

    Args:
        cb: optional callback object exposing ``before_calc(i)`` and
            ``after_calc(i)``; both hooks are required when `cb` is given.
            ``None`` (the default) disables notifications.

    Returns:
        The sum of squares (30).
    """
    res = 0
    for i in range(5):
        # `is not None` rather than truthiness: a falsy callback object
        # (e.g. one defining __bool__/__len__) must still be notified.
        if cb is not None: cb.before_calc(i)
        res += i*i
        sleep(1)  # simulate an expensive step
        if cb is not None: cb.after_calc(i)
    return res

In [None]:
class PrintStepCallback:
    """Callback that prints a line before and after each calculation step.

    Both hooks accept ``*args, **kwargs``: `slow_calculation` passes the step
    index, so a bare ``def before_calc(self)`` would raise TypeError. Swallowing
    extra arguments is a bit loose, but it keeps the callback working even if
    the caller changes what it passes.
    """

    def before_calc(self, *args, **kwargs):
        print("About to start")

    def after_calc(self, *args, **kwargs):
        print("Done step")


In [None]:
slow_calculation(cb=PrintStepCallback())

About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step


30

### Modifying behavior

In [None]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, letting an optional callback observe and stop it.

    Args:
        cb: optional callback object. Both hooks are optional:
            ``before_calc(i)`` is called before step ``i``;
            ``after_calc(i, res)`` is called after it with the running total,
            and a truthy return stops the loop early.

    Returns:
        The running total when the loop ends (30 if never stopped early).
    """
    res = 0
    for i in range(5):
        # getattr(..., None) covers both "no callback" (cb is None) and
        # "callback without this hook" without relying on the truthiness of cb.
        before_calc = getattr(cb, 'before_calc', None)
        if before_calc is not None: before_calc(i)
        res += i*i
        sleep(1)  # simulate an expensive step
        after_calc = getattr(cb, 'after_calc', None)
        if after_calc is not None and after_calc(i, res):
            print("stopping early")
            break
    return res

In [None]:
class PrintAfterCallback:
    """Callback that logs the running total and asks to stop once it exceeds 10."""

    def after_calc(self, epoch, val):
        print(f"After {epoch}: {val}")
        # True asks the caller to stop; None (falsy) lets it continue.
        return True if val > 10 else None

In [None]:
slow_calculation(cb=PrintAfterCallback())

After 0: 0
After 1: 1
After 2: 5
After 3: 14
stopping early


14

In [None]:
class SlowCalculator():
    """Stateful sum-of-squares calculator whose callbacks can read and modify it.

    The running total lives on ``self.res``; every hook receives the calculator
    itself as its first argument, so it can inspect or change ``res``.
    """

    def __init__(self, cb=None): self.cb, self.res = cb, 0

    def callback(self, cb_name, *args):
        """Call hook `cb_name` on the callback as ``hook(self, *args)``.

        Returns the hook's result, or None when there is no callback or the
        callback does not define that hook.
        """
        # `is None` rather than truthiness: a falsy callback object must still run.
        if self.cb is None: return None
        hook = getattr(self.cb, cb_name, None)
        if hook is None: return None
        return hook(self, *args)

    def calc(self):
        """Run the 5-step calculation from scratch and return the final ``res``.

        ``res`` is reset to 0 first so repeated calls give the same answer
        instead of accumulating onto the previous run. A truthy return from the
        ``after_calc`` hook stops the loop early.
        """
        self.res = 0
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(1)  # simulate an expensive step
            if self.callback('after_calc', i):
                print("stopping early")
                break
        return self.res



In [None]:
class ModifyingCallback():
    """Callback that reports progress, doubles small totals and stops past 10.

    It writes to ``calc.res`` directly, relying on the calculator passing
    itself as the first hook argument.
    """

    def after_calc(self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        total = calc.res
        if total > 10:
            return True
        if total < 3:
            calc.res = total * 2

In [None]:
calculator = SlowCalculator(cb=ModifyingCallback())
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15