In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt

# Callbacks

## Callbacks as GUI events

In [3]:
import ipywidgets as widgets

In [4]:
def f(o):
    """Button click handler that just prints a greeting.

    `o` is the single positional argument passed by on_click (the clicked
    button, per ipywidgets); it is not used here.
    """
    print('hi')

From the ipywidgets docs:
    The button is used to handle mouse clicks. The on_click() method of the Button can be used to register a
    function to be called when the button is clicked.

In [7]:
# A button widget labelled 'Click me'; it is rendered when displayed in a later cell.
w = widgets.Button(description='Click me')

In [8]:
# Bare last expression of the cell, so Jupyter renders the button here.
w

Button(description='Click me', style=ButtonStyle())

hi
hi


In [9]:
# Register f as the click handler; each click on the button then prints 'hi'.
w.on_click(f)

## Creating your own callback: understanding callbacks

In [10]:
from time import sleep

In [13]:
# simulate an epoch calculation
def slow_calculation():
    """Return the sum of squares of 0..4, sleeping one second after each term.

    The sleep stands in for an expensive "epoch" of work.
    """
    total = 0
    for step in range(5):
        total = total + step ** 2
        sleep(1)
    return total

In [14]:
# Takes about 5 s: five terms, each followed by sleep(1).
slow_calculation()

30

In [15]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 (one second per term), reporting progress.

    cb: optional callable, invoked as cb(step) after each term is added.
    Returns the final sum.
    """
    total = 0
    for step in range(5):
        total += step ** 2
        sleep(1)
        if cb:
            cb(step)  # call to callback: this "epoch" is done
    return total

In [17]:
def show_progress(epoch):
    """Progress callback: announce that `epoch` has finished."""
    print("Awesome! We've finished epoch ", epoch)

In [18]:
slow_calculation(cb=show_progress)

Awesome! We've finished epoch  0
Awesome! We've finished epoch  1
Awesome! We've finished epoch  2
Awesome! We've finished epoch  3
Awesome! We've finished epoch  4


30

## Lambdas and partials

Define a function inline, right where we use it, with **lambda** notation.

In [21]:
slow_calculation(cb=lambda epoch: print("Awesome! We've finished epoch ", epoch))

Awesome! We've finished epoch  0
Awesome! We've finished epoch  1
Awesome! We've finished epoch  2
Awesome! We've finished epoch  3
Awesome! We've finished epoch  4


30

In [22]:
def show_progress(exclamation, epoch):
    """Progress callback: print `exclamation` followed by the finished `epoch`."""
    print(exclamation, " We've finished epoch ", epoch)

In [23]:
from functools import partial

In [24]:
slow_calculation(cb=partial(show_progress, 'Great'))

Great  We've finished epoch  0
Great  We've finished epoch  1
Great  We've finished epoch  2
Great  We've finished epoch  3
Great  We've finished epoch  4


30

In [25]:
# partial pre-binds 'Great' as the first argument: f2(epoch) behaves like show_progress('Great', epoch).
f2 = partial(show_progress, 'Great')

## Callbacks as callable classes

In [34]:
class ProgressShowingCallback():
    """Callable progress reporter: calling an instance prints a message for an epoch.

    exclamation: word printed at the start of every message (default 'Awesome').
    """
    def __init__(self, exclamation='Awesome'):
        self.exclamation = exclamation

    def __call__(self, epoch):
        """Print the progress message for `epoch` (any printable value)."""
        # Message text matches show_progress; fixes the former "eopch" typo.
        print(self.exclamation, ' We\'ve finished epoch ', epoch)

In [35]:
# init: the exclamation is stored on the instance and reused by every call
cb = ProgressShowingCallback(exclamation='Wonderbar')

In [36]:
# call: invokes cb.__call__('hi'); any printable value is accepted as the epoch
cb('hi')

Wonderbar  We've finished eopch  hi


In [37]:
slow_calculation(cb=cb)

Wonderbar  We've finished eopch  0
Wonderbar  We've finished eopch  1
Wonderbar  We've finished eopch  2
Wonderbar  We've finished eopch  3
Wonderbar  We've finished eopch  4


30

## Multiple callback functions; `*args` and `**kwargs`

In [38]:
def f(*args, **kwargs):
    """Show how arguments are packed: positionals into a tuple, keywords into a dict."""
    pieces = ('args: ', args, '; kwargs: ', kwargs)
    print(*pieces)

In [40]:
# things passed as positional arguments end up in a tuple (args)
# things passed as keyword arguments end up in a dictionary (kwargs)
f(3, 'a', thing1='hello')

args:  (3, 'a') ; kwargs:  {'thing1': 'hello'}


In [43]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 (one second per term) with before/after hooks.

    cb: optional object with before_calc(step) and after_calc(step, val=running_sum)
        methods, called around each term.
    Returns the final sum.
    """
    total = 0
    for step in range(5):
        if cb:
            cb.before_calc(step)
        total += step ** 2
        sleep(1)
        if cb:
            cb.after_calc(step, val=total)
    return total

In [46]:
# let's use this:
class PrintStepCallback():
    """Callback that prints a fixed message before and after every step.

    Both hooks accept and ignore any arguments.
    """
    def before_calc(self, *args, **kwargs):
        print('About to start')

    def after_calc(self, *args, **kwargs):
        print('Done step')

In [47]:
slow_calculation(cb=PrintStepCallback())

About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step


30

In [56]:
# let's use this:
class PrintStepCallback():
    """Callback that reports the epoch (first positional arg) around every step.

    after_calc also prints whatever keyword arguments it was given.
    """
    def before_calc(self, *args, **kwargs):
        epoch = args[0]
        print('About to start epoch ', epoch)

    def after_calc(self, *args, **kwargs):
        epoch = args[0]
        print('Done step epoch', epoch, '; ', kwargs)

In [57]:
slow_calculation(cb=PrintStepCallback())

About to start epoch  0
Done step epoch 0 ;  {'val': 0}
About to start epoch  1
Done step epoch 1 ;  {'val': 1}
About to start epoch  2
Done step epoch 2 ;  {'val': 5}
About to start epoch  3
Done step epoch 3 ;  {'val': 14}
About to start epoch  4
Done step epoch 4 ;  {'val': 30}


30

## Modifying Behaviour

In [58]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 (one second per term) with optional, stoppable hooks.

    cb: optional object; each hook is called only if cb defines it.
        before_calc(step) runs before each term.
        after_calc(step, running_sum) runs after it; a truthy return prints
        'stopping early' and ends the loop.
    Returns the running sum at the point the loop finished or stopped.
    """
    total = 0
    for step in range(5):
        if cb and hasattr(cb, 'before_calc'):
            cb.before_calc(step)
        total += step ** 2
        sleep(1)
        # short-circuit: after_calc is only invoked when cb provides it
        if cb and hasattr(cb, 'after_calc') and cb.after_calc(step, total):
            print('stopping early')
            break
    return total

In [59]:
class PrintAfterCallback():
    """Callback that logs each step's running value and requests an early stop.

    threshold: stop once the running value exceeds this (default 10, the
        previously hard-coded value).
    """
    def __init__(self, threshold=10):
        self.threshold = threshold

    def after_calc(self, epoch, val):
        """Print the step's result; return True to ask the loop to stop early, else False."""
        print('After ', epoch, ' : ', val)
        # Always return a bool (previously None on the "keep going" path).
        return val > self.threshold

In [60]:
slow_calculation(cb=PrintAfterCallback())

After  0  :  0
After  1  :  1
After  2  :  5
After  3  :  14
stopping early


14